feat(compiler): Add the HLFHELinalg.matmul_int_eint operator
@@ -402,8 +402,47 @@ def MatMulEintIntOp : HLFHELinalg_Op<"matmul_eint_int", [TensorBinaryEintInt]> {
   let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
 
   let verifier = [{
-    return ::mlir::zamalang::HLFHELinalg::verifyMatmul(*this);
+    return ::mlir::zamalang::HLFHELinalg::verifyMatmul<mlir::zamalang::HLFHELinalg::MatMulEintIntOp>(*this);
   }];
 }
 
+def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> {
+  let summary = "Returns a tensor that contains the result of the matrix multiplication of a matrix of clear integers and a matrix of encrypted integers.";
+
+  let description = [{
+    Performs a matrix multiplication of a matrix of clear integers and a matrix of encrypted integers.
+    The width of the clear integers must be less than or equal to the width of the encrypted integers.
+
+    ```mlir
+    "HLFHELinalg.matmul_int_eint"(%a, %b) : (tensor<MxNxip'>, tensor<NxPx!HLFHE.eint<p>>) -> tensor<MxPx!HLFHE.eint<p>>
+    ```
+
+    Examples:
+
+    ```mlir
+    // Returns the matrix multiplication of a 3x2 matrix of clear integers and a 2x3 matrix of encrypted integers.
+    //         [ 1, 2, 3]
+    //         [ 2, 3, 4]
+    //        *
+    // [1,2]   [ 5, 8,11]
+    // [3,4] = [11,18,25]
+    // [5,6]   [17,28,39]
+    //
+    "HLFHELinalg.matmul_int_eint"(%a, %b) : (tensor<3x2xi7>, tensor<2x3x!HLFHE.eint<6>>) -> tensor<3x3x!HLFHE.eint<6>>
+    ```
+  }];
+
+  let arguments = (ins
+    Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$lhs,
+    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$rhs
+  );
+
+  let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
+
+  let verifier = [{
+    return ::mlir::zamalang::HLFHELinalg::verifyMatmul<mlir::zamalang::HLFHELinalg::MatMulIntEintOp>(*this);
+  }];
+}
+
 #endif
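For quick comparison, a minimal sketch of the two matmul variants side by side (the value names `%e`, `%c`, `%c2`, `%e2` are hypothetical; the operand and result types are taken verbatim from the test files later in this commit):

```mlir
// encrypted lhs, clear rhs (pre-existing op)
%0 = "HLFHELinalg.matmul_eint_int"(%e, %c) : (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
// clear lhs, encrypted rhs (the op added by this commit)
%1 = "HLFHELinalg.matmul_int_eint"(%c2, %e2) : (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
```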
@@ -598,11 +598,12 @@ struct HLFHELinalgNegEintToLinalgGeneric
   };
 };
 
-// This rewrite pattern transforms any instance of
-// operators `HLFHELinalg.matmul_eint_int` to an instance of `linalg.generic`
-// with an appropriate region using `HLFHE.mul_eint_int` and `HLFHE.add_eint`
-// operation, an appropriate specification for the iteration dimensions and
-// appropriate operations managing the accumulator of `linalg.generic`.
+// This template rewrite pattern transforms any instance of
+// operators `HLFHELinalgMatmulOp` to an instance of `linalg.generic`
+// with an appropriate region using a builder that creates the multiplication
+// operators and the `HLFHE.add_eint` operation, an appropriate specification
+// for the iteration dimensions and appropriate operations managing the
+// accumulator of `linalg.generic`.
 //
 // Example:
 //
@@ -633,27 +634,33 @@ struct HLFHELinalgNegEintToLinalgGeneric
 //        outs(%C : tensor<MxNx!HLFHE.eint<p>>)
 //        {
 //           ^bb0(%a: !HLFHE.eint<p>, %b: ip', %c: !HLFHE.eint<p>) :
-//             %d = "HLFHE.mul_eint_int"(%a, %b) :
-//                      (!HLFHE.eint<p>, ip') -> !HLFHE.eint<p>
+//             %d = createMulOp(%a, %b): !HLFHE.eint<p>
 //             %e = "HLFHE.add_eint"(%c, %d):
 //                      (!HLFHE.eint<p>, !HLFHE.eint<p>) -> !HLFHE.eint<p>
 //             linalg.yield %e : !HLFHE.eint<p>
 //        }
 //
-struct HLFHELinalgMatmulEintIntToLinalgGeneric
-    : public mlir::OpRewritePattern<
-          mlir::zamalang::HLFHELinalg::MatMulEintIntOp> {
-  HLFHELinalgMatmulEintIntToLinalgGeneric(::mlir::MLIRContext *context,
-                                          mlir::PatternBenefit benefit = 1)
-      : ::mlir::OpRewritePattern<mlir::zamalang::HLFHELinalg::MatMulEintIntOp>(
-            context, benefit) {}
+template <typename HLFHELinalgMatmulOp>
+struct HLFHELinalgMatmulToLinalgGeneric
+    : public mlir::OpRewritePattern<HLFHELinalgMatmulOp> {
+  HLFHELinalgMatmulToLinalgGeneric(
+      mlir::MLIRContext *context,
+      std::function<mlir::zamalang::HLFHE::MulEintIntOp(
+          mlir::OpBuilder &, mlir::Location, mlir::Type, mlir::Value,
+          mlir::Value)>
+          createMulOp,
+      mlir::PatternBenefit benefit = 1)
+      : ::mlir::OpRewritePattern<HLFHELinalgMatmulOp>(context, benefit),
+        createMulOp(createMulOp) {}
 
   ::mlir::LogicalResult
-  matchAndRewrite(mlir::zamalang::HLFHELinalg::MatMulEintIntOp matmulOp,
+  matchAndRewrite(HLFHELinalgMatmulOp matmulOp,
                   ::mlir::PatternRewriter &rewriter) const override {
+    mlir::Location matmulLoc = matmulOp.getLoc();
     mlir::RankedTensorType resultTy =
         ((mlir::Type)matmulOp->getResult(0).getType())
             .cast<mlir::RankedTensorType>();
+    mlir::Type resultElementTy = resultTy.getElementType();
     // Create tensor.generate for initial value
     auto generateBody = [&](mlir::OpBuilder &nestedBuilder,
                             mlir::Location nestedLoc,
@@ -661,17 +668,13 @@ struct HLFHELinalgMatmulEintIntToLinalgGeneric
       // %z = "HLFHE.zero" : () -> !HLFHE.eint<2>
       mlir::zamalang::HLFHE::ZeroEintOp zeroOp =
           nestedBuilder.create<mlir::zamalang::HLFHE::ZeroEintOp>(
-              matmulOp.getLoc(), resultTy.getElementType());
+              matmulLoc, resultElementTy);
       // linalg.yield %z : !HLFHE.eint<p>
-      nestedBuilder.create<mlir::tensor::YieldOp>(matmulOp.getLoc(),
+      nestedBuilder.create<mlir::tensor::YieldOp>(matmulLoc,
                                                   zeroOp.getResult());
     };
     mlir::tensor::GenerateOp init = rewriter.create<mlir::tensor::GenerateOp>(
-        matmulOp.getLoc(), (mlir::Type)resultTy, mlir::ValueRange{},
-        generateBody);
-    // linalg.init_tensor for initial value
-    // mlir::Value init = rewriter.create<mlir::linalg::InitTensorOp>(
-    //     matmulOp.getLoc(), resultTy.getShape(), resultTy.getElementType());
+        matmulLoc, (mlir::Type)resultTy, mlir::ValueRange{}, generateBody);
     // Create the affine #maps_0
     llvm::SmallVector<mlir::AffineMap> maps{
         // (m, n, p) -> (m, p),
@@ -698,17 +701,15 @@ struct HLFHELinalgMatmulEintIntToLinalgGeneric
                            mlir::ValueRange blockArgs) {
       // "HLFHE.mul_eint_int"(%a, %b) : (!HLFHE.eint<p>, ip') -> !HLFHE.eint<p>
       mlir::zamalang::HLFHE::MulEintIntOp mulEintIntOp =
-          nestedBuilder.create<mlir::zamalang::HLFHE::MulEintIntOp>(
-              matmulOp.getLoc(), resultTy.getElementType(), blockArgs[0],
-              blockArgs[1]);
+          createMulOp(nestedBuilder, matmulLoc, resultElementTy, blockArgs[0],
+                      blockArgs[1]);
       // "HLFHE.add_eint"(%c, %d): (!HLFHE.eint<p>, !HLFHE.eint<p>) ->
       // !HLFHE.eint<p>
       mlir::zamalang::HLFHE::AddEintOp addEintOp =
           nestedBuilder.create<mlir::zamalang::HLFHE::AddEintOp>(
-              matmulOp.getLoc(), resultTy.getElementType(), blockArgs[2],
-              mulEintIntOp);
+              matmulLoc, resultElementTy, blockArgs[2], mulEintIntOp);
       // linalg.yield %e : !HLFHE.eint<p>
-      nestedBuilder.create<mlir::linalg::YieldOp>(matmulOp.getLoc(),
+      nestedBuilder.create<mlir::linalg::YieldOp>(matmulLoc,
                                                   addEintOp.getResult());
     };
 
@@ -720,14 +721,19 @@ struct HLFHELinalgMatmulEintIntToLinalgGeneric
     llvm::StringRef call{""};
 
     mlir::linalg::GenericOp genericOp =
-        rewriter.create<mlir::linalg::GenericOp>(matmulOp.getLoc(), resTypes,
-                                                 ins, outs, maps, iteratorTypes,
-                                                 doc, call, bodyBuilder);
+        rewriter.create<mlir::linalg::GenericOp>(matmulLoc, resTypes, ins, outs,
+                                                 maps, iteratorTypes, doc, call,
+                                                 bodyBuilder);
 
     rewriter.replaceOp(matmulOp, {genericOp.getResult(0)});
 
     return ::mlir::success();
   };
 
+private:
+  std::function<mlir::zamalang::HLFHE::MulEintIntOp(
+      mlir::OpBuilder &, mlir::Location, mlir::Type, mlir::Value, mlir::Value)>
+      createMulOp;
 };
 
 namespace {
@@ -771,7 +777,20 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
       &getContext());
   patterns.insert<HLFHELinalgApplyLookupTableToLinalgGeneric>(&getContext());
   patterns.insert<HLFHELinalgNegEintToLinalgGeneric>(&getContext());
-  patterns.insert<HLFHELinalgMatmulEintIntToLinalgGeneric>(&getContext());
+  patterns.insert<HLFHELinalgMatmulToLinalgGeneric<
+      mlir::zamalang::HLFHELinalg::MatMulEintIntOp>>(
+      &getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
+                        mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
+        return builder.create<mlir::zamalang::HLFHE::MulEintIntOp>(loc, type,
+                                                                   arg0, arg1);
+      });
+  patterns.insert<HLFHELinalgMatmulToLinalgGeneric<
+      mlir::zamalang::HLFHELinalg::MatMulIntEintOp>>(
+      &getContext(), [](mlir::OpBuilder &builder, mlir::Location loc,
+                        mlir::Type type, mlir::Value arg0, mlir::Value arg1) {
+        return builder.create<mlir::zamalang::HLFHE::MulEintIntOp>(loc, type,
+                                                                   arg1, arg0);
+      });
   patterns.insert<HLFHELinalgApplyMultiLookupTableToLinalgGeneric>(
       &getContext());
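To make the lowering concrete, here is a sketch of the `linalg.generic` both instantiations produce, not taken from the commit itself: the affine-map names `#map_mp`, `#map_pn`, `#map_mn` are placeholders for the maps built by the pattern, and the `tensor.generate` syntax is schematic. The only difference between the two instantiations is the operand order the `createMulOp` lambda passes to `HLFHE.mul_eint_int`, which expects the encrypted value first; this is why no separate `mul_int_eint` op is needed.

```mlir
// Lowering of "HLFHELinalg.matmul_int_eint"(%a, %b)
//   %a : tensor<3x2xi7>, %b : tensor<2x3x!HLFHE.eint<6>>
%acc = tensor.generate {
^bb0(%i: index, %j: index):
  %z = "HLFHE.zero"() : () -> !HLFHE.eint<6>
  tensor.yield %z : !HLFHE.eint<6>
} : tensor<3x3x!HLFHE.eint<6>>
%res = linalg.generic {indexing_maps = [#map_mp, #map_pn, #map_mn],
                       iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%a, %b : tensor<3x2xi7>, tensor<2x3x!HLFHE.eint<6>>)
    outs(%acc : tensor<3x3x!HLFHE.eint<6>>) {
^bb0(%lhs: i7, %rhs: !HLFHE.eint<6>, %c: !HLFHE.eint<6>):
  // createMulOp for matmul_int_eint swaps the block arguments: encrypted first
  %d = "HLFHE.mul_eint_int"(%rhs, %lhs) : (!HLFHE.eint<6>, i7) -> !HLFHE.eint<6>
  %e = "HLFHE.add_eint"(%c, %d) : (!HLFHE.eint<6>, !HLFHE.eint<6>) -> !HLFHE.eint<6>
  linalg.yield %e : !HLFHE.eint<6>
} -> tensor<3x3x!HLFHE.eint<6>>
```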
@@ -686,6 +686,71 @@ static llvm::APInt getSqMANP(
   return accNorm;
 }
 
+static llvm::APInt getSqMANP(
+    mlir::zamalang::HLFHELinalg::MatMulIntEintOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  mlir::RankedTensorType rhsTy =
+      op.rhs().getType().cast<mlir::RankedTensorType>();
+  mlir::RankedTensorType lhsTy =
+      op.lhs().getType().cast<mlir::RankedTensorType>();
+
+  mlir::Type iTy = lhsTy.getElementType();
+
+  assert(iTy.isSignlessInteger() &&
+         "Only multiplications with signless integers are currently allowed");
+
+  assert(
+      operandMANPs.size() == 2 &&
+      operandMANPs[1]->getValue().getMANP().hasValue() &&
+      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
+
+  llvm::APInt rhsNorm = operandMANPs[1]->getValue().getMANP().getValue();
+  // Initial value of the accumulator
+  llvm::APInt accNorm = llvm::APInt{1, 1, false};
+
+  mlir::arith::ConstantOp cstOp =
+      llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
+          op->getOpOperand(0).get().getDefiningOp());
+  mlir::DenseIntElementsAttr denseVals =
+      cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
+            : nullptr;
+
+  if (denseVals) {
+    // For a constant operand use the actual constants to calculate the 2-norm.
+    // For tensor<MxN> = tensor<MxP> * tensor<PxN> compute the max 2-norm of
+    // the result.
+    int64_t M = lhsTy.getShape()[0];
+    int64_t N = rhsTy.getShape()[1];
+    int64_t P = rhsTy.getShape()[0];
+    for (int64_t m = 0; m < M; m++) {
+      for (int64_t n = 0; n < N; n++) {
+        llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
+        for (int64_t p = 0; p < P; p++) {
+          llvm::APInt cst = denseVals.getFlatValue<llvm::APInt>(m * P + p);
+          llvm::APInt lhsNorm = APIntWidthExtendUSq(cst);
+          llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
+          tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
+        }
+        accNorm = APIntUMax(accNorm, tmpNorm);
+      }
+    }
+  } else {
+    // For a dynamic operand conservatively assume that the value is
+    // the maximum for the integer width.
+    llvm::APInt lhsNorm = conservativeIntNorm2Sq(iTy);
+    // For tensor<MxN> = tensor<MxP> * tensor<PxN> there are P
+    // HLFHE.mul_eint_int and HLFHE.add_eint operations for each element of
+    // the result.
+    int64_t P = rhsTy.getShape()[0];
+    for (int64_t i = 0; i < P; i++) {
+      llvm::APInt mulNorm = APIntWidthExtendUMul(rhsNorm, lhsNorm);
+      accNorm = APIntWidthExtendUAdd(mulNorm, accNorm);
+    }
+  }
+
+  return accNorm;
+}
+
 static llvm::APInt getSqMANP(
     mlir::tensor::ExtractOp op,
     llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
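Stated as a formula (a restatement of the code above, not part of the commit), for tensor<MxN> = tensor<MxP> * tensor<PxN> with a dynamic clear operand of bit width $w$ the analysis computes the squared MANP of each result element as

$$ \nu_{\mathrm{acc}}^2 \;=\; 1 + \sum_{p=1}^{P} (2^{w})^{2}\,\nu_{\mathrm{rhs}}^{2} \;=\; 1 + P \cdot 4^{w}\,\nu_{\mathrm{rhs}}^{2}, $$

where $\nu_{\mathrm{rhs}}^2$ is the squared MANP of the encrypted operand and the leading $1$ accounts for the zero-initialized accumulator. For a constant clear operand, $(2^w)^2$ is replaced by the actual squared constants and the maximum over all result positions is kept. The tests below exercise exactly this: with $P = 1$, $w = 3$, $\nu_{\mathrm{rhs}}^2 = 1$, the op is annotated with MANP $= \lceil\sqrt{1 + 64}\rceil = 9$.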
@@ -823,6 +888,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
                    llvm::dyn_cast<mlir::zamalang::HLFHELinalg::MatMulEintIntOp>(
                        op)) {
       norm2SqEquiv = getSqMANP(matmulEintIntOp, operands);
+    } else if (auto matmulIntEintOp =
+                   llvm::dyn_cast<mlir::zamalang::HLFHELinalg::MatMulIntEintOp>(
+                       op)) {
+      norm2SqEquiv = getSqMANP(matmulIntEintOp, operands);
     } else if (llvm::isa<
                    mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp,
                    mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>(
@@ -309,14 +309,15 @@ verifyApplyMultiLookupTable(ApplyMultiLookupTableEintOp &op) {
   return ::mlir::success();
 }
 
-/// Verify the matmul shapes, the type of tensor elements are checked by
-/// TensorBinaryEintInt
-mlir::LogicalResult verifyMatmul(MatMulEintIntOp &op) {
-  auto lhsTy = op.lhs().getType().cast<mlir::RankedTensorType>();
+/// Verify the matmul shapes; the types of the tensor elements are checked
+/// elsewhere.
+template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
+  auto lhsTy = ((mlir::Type)op.lhs().getType()).cast<mlir::RankedTensorType>();
 
-  auto rhsTy = op.rhs().getType().cast<mlir::RankedTensorType>();
+  auto rhsTy = ((mlir::Type)op.rhs().getType()).cast<mlir::RankedTensorType>();
 
-  auto resultTy = op.getResult().getType().cast<mlir::RankedTensorType>();
+  auto resultTy =
+      ((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
 
   if (lhsTy.getShape().size() != 2 || rhsTy.getShape().size() != 2) {
     op.emitOpError() << "should have 2D tensors as operands";
@@ -333,9 +334,8 @@ mlir::LogicalResult verifyMatmul(MatMulEintIntOp &op) {
                                          rhsTy.getDimSize(1)};
   if (!resultTy.hasStaticShape(expectedShape)) {
     op.emitOpError() << "should have the result shape compatible with operands "
-                        "shape, expect "
-                     << expectedShape[0] << "x" << expectedShape[1]
-                     << " as the shape of the result";
+                     << "shape, expect " << expectedShape[0] << "x"
+                     << expectedShape[1] << " as the shape of the result";
     return mlir::failure();
   }
   return mlir::success();
@@ -137,6 +137,32 @@ func @apply_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>
 
 // -----
 
+/////////////////////////////////////////////////
+// HLFHELinalg.apply_multi_lookup_table
+/////////////////////////////////////////////////
+
+func @apply_multi_lookup_table(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>> {
+  // CHECK: %[[RES:.*]] = "HLFHELinalg.apply_multi_lookup_table"(%[[T:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>>
+  %res = "HLFHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>>
+  return %res : tensor<3x3x!HLFHE.eint<3>>
+}
+
+// -----
+
+func @apply_multi_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>> {
+  // CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  %0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
+  // CHECK-NEXT: %[[RES:.*]] = "HLFHELinalg.apply_multi_lookup_table"(%[[V0:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!HLFHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>>
+  %res = "HLFHELinalg.apply_multi_lookup_table"(%0, %luts) : (tensor<8x!HLFHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>>
+  return %res : tensor<8x!HLFHE.eint<3>>
+}
+
 // -----
 
 /////////////////////////////////////////////////
 // HLFHELinalg.matmul_eint_int
 /////////////////////////////////////////////////
 
 func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!HLFHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!HLFHE.eint<2>> {
   // p = 0
   // acc = manp(0) = 1
@@ -214,20 +240,85 @@ func @matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!HLFHE.eint<2>>) -> tensor<3
   return %1 : tensor<3x2x!HLFHE.eint<2>>
 }
 
+/////////////////////////////////////////////////
+// HLFHELinalg.matmul_int_eint
+/////////////////////////////////////////////////
+
 // -----
 
-func @apply_multi_lookup_table(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>> {
-  // CHECK: %[[RES:.*]] = "HLFHELinalg.apply_multi_lookup_table"(%[[T:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>>
-  %res = "HLFHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x3x!HLFHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!HLFHE.eint<3>>
-  return %res : tensor<3x3x!HLFHE.eint<3>>
-}
+func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> {
+  // p = 0
+  // acc = manp(0) = 1
+  // mul = manp(mul_eint_int(eint<2>, i3)) = 1 * (2^3)^2 = 64
+  // manp(add_eint(mul, acc)) = 64 + 1 = 65
+  // ceil(sqrt(65)) = 9
+  // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
+  %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
+  return %1 : tensor<3x2x!HLFHE.eint<2>>
+}
 
 // -----
 
-func @apply_multi_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>> {
-  // CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
-  %0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!HLFHE.eint<2>>, tensor<8xi3>) -> tensor<8x!HLFHE.eint<2>>
-  // CHECK-NEXT: %[[RES:.*]] = "HLFHELinalg.apply_multi_lookup_table"(%[[V0:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!HLFHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>>
-  %res = "HLFHELinalg.apply_multi_lookup_table"(%0, %luts) : (tensor<8x!HLFHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!HLFHE.eint<3>>
-  return %res : tensor<8x!HLFHE.eint<3>>
-}
+func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> {
+  // p = 0
+  // acc = manp(0) = 1
+  // mul = manp(mul_eint_int(eint<2>, i3)) = 1 * (2^3)^2 = 64
+  // manp(add_eint(mul, acc)) = 64 + 1 = 65
+  // p = 1
+  // manp(mul_eint_int(eint<2>, i3)) = 1 * (2^3)^2 = 64
+  // manp(add_eint(mul, acc)) = 64 + 65 = 129
+  // ceil(sqrt(129)) = 12
+  // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}}
+  %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x2xi3>, tensor<2x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
+  return %1 : tensor<3x2x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @matmul_int_eint_cst_p_1(%arg0: tensor<1x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> {
+  %0 = arith.constant dense<[[3], [1]]> : tensor<2x1xi3>
+  // c(m,n) = a(m,p) * b(p,n) the max cst is used for m = 0
+  // acc = manp(0) = 1
+  // mul = manp(mul_eint_int(eint<2>, 3)) = 1 * 3^2 = 9
+  // manp(add_eint(mul, acc)) = 9 + 1 = 10
+  // ceil(sqrt(10)) = 4
+  // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
+  %1 = "HLFHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x1xi3>, tensor<1x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>>
+  return %1 : tensor<2x3x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @matmul_int_eint_cst_p_2_n_0(%arg0: tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> {
+  %0 = arith.constant dense<[[3, 4],[1, 1]]> : tensor<2x2xi3>
+  // c(m,n) = a(m,p) * b(p,n) the max csts [3,4] are used for m = 0
+  // p = 0
+  // acc = manp(0) = 1
+  // mul = manp(mul_eint_int(eint<2>, 3)) = 1 * 3^2 = 9
+  // manp(add_eint(mul, acc)) = 9 + 1 = 10
+  // p = 1
+  // mul = manp(mul_eint_int(eint<2>, 4)) = 1 * 4^2 = 16
+  // manp(add_eint(mul, acc)) = 16 + 10 = 26
+  // ceil(sqrt(26)) = 6
+  // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 6 : ui{{[0-9]+}}}
+  %1 = "HLFHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>>
+  return %1 : tensor<2x3x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>> {
+  %0 = arith.constant dense<[[4, 1],[3, 1]]> : tensor<2x2xi3>
+  // c(m,n) = a(m,p) * b(p,n) the max csts [4,1] are used for m = 0
+  // p = 0
+  // acc = manp(0) = 1
+  // mul = manp(mul_eint_int(eint<2>, 4)) = 1 * 4^2 = 16
+  // manp(add_eint(mul, acc)) = 16 + 1 = 17
+  // p = 1
+  // mul = manp(mul_eint_int(eint<2>, 1)) = 1 * 1^2 = 1
+  // manp(add_eint(mul, acc)) = 1 + 17 = 18
+  // ceil(sqrt(18)) = 5
+  // CHECK: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
+  %1 = "HLFHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!HLFHE.eint<2>>) -> tensor<2x3x!HLFHE.eint<2>>
+  return %1 : tensor<2x3x!HLFHE.eint<2>>
+}
@@ -194,4 +194,38 @@ func @matmul_eint_int(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>)
   return %1 : tensor<4x2x!HLFHE.eint<2>>
 }
 
+// -----
+
+/////////////////////////////////////////////////
+// HLFHELinalg.matmul_int_eint
+/////////////////////////////////////////////////
+
+func @matmul_int_eint(%arg0: tensor<2x3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> {
+  // expected-error @+1 {{'HLFHELinalg.matmul_int_eint' op should have 2D tensors as operands}}
+  %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<2x3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
+  return %1 : tensor<3x2x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<2x4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> {
+  // expected-error @+1 {{'HLFHELinalg.matmul_int_eint' op should have 2D tensors as operands}}
+  %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<2x4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
+  return %1 : tensor<3x2x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<5x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> {
+  // expected-error @+1 {{'HLFHELinalg.matmul_int_eint' op should have the dimension #0 of operand #1equals to the dimension #1 of operand #0, expect 4 got 5}}
+  %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<5x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
+  return %1 : tensor<3x2x!HLFHE.eint<2>>
+}
+
+// -----
+
+func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) -> tensor<4x2x!HLFHE.eint<2>> {
+  // expected-error @+1 {{'HLFHELinalg.matmul_int_eint' op should have the result shape compatible with operands shape, expect 3x2 as the shape of the result}}
+  %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<4x2x!HLFHE.eint<2>>
+  return %1 : tensor<4x2x!HLFHE.eint<2>>
+}
@@ -316,3 +316,16 @@ func @matmul_eint_int(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4x2xi3>)
   %1 = "HLFHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x4x!HLFHE.eint<2>>, tensor<4x2xi3>) -> tensor<3x2x!HLFHE.eint<2>>
   return %1 : tensor<3x2x!HLFHE.eint<2>>
 }
+
+/////////////////////////////////////////////////
+// HLFHELinalg.matmul_int_eint
+/////////////////////////////////////////////////
+
+// CHECK-LABEL: @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
+func @matmul_int_eint(%arg0: tensor<3x4xi3>, %arg1: tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>> {
+  // CHECK-NEXT: %[[V1:.*]] = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1) : (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
+  // CHECK-NEXT: return %[[V1]] : tensor<3x2x!HLFHE.eint<2>>
+
+  %1 = "HLFHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x4xi3>, tensor<4x2x!HLFHE.eint<2>>) -> tensor<3x2x!HLFHE.eint<2>>
+  return %1 : tensor<3x2x!HLFHE.eint<2>>
+}
@@ -1132,8 +1132,7 @@ TEST(End2EndJit_HLFHELinalg, apply_multi_lookup_table_with_boradcast) {
 
   mlir::zamalang::TensorLambdaArgument<
       mlir::zamalang::IntLambdaArgument<uint64_t>>
-      lutsArg(llvm::MutableArrayRef<uint64_t>((uint64_t *)luts, 3 * 4),
-              {3, 4});
+      lutsArg(llvm::MutableArrayRef<uint64_t>((uint64_t *)luts, 3 * 4), {3, 4});
 
   llvm::Expected<std::vector<uint64_t>> res =
       lambda.operator()<std::vector<uint64_t>>({&tArg, &lutsArg});
@@ -1276,6 +1275,62 @@ TEST(End2EndJit_HLFHELinalg, matmul_eint_int) {
   }
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// HLFHELinalg matmul_int_eint ////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////
+
+TEST(End2EndJit_HLFHELinalg, matmul_int_eint) {
+
+  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+// Returns the matrix multiplication of a 3x2 matrix of clear integers and a 2x3 matrix of encrypted integers.
+//         [ 1, 2, 3]
+//         [ 2, 3, 4]
+//        *
+// [1,2]   [ 5, 8,11]
+// [3,4] = [11,18,25]
+// [5,6]   [17,28,39]
+func @main(%a: tensor<3x2xi7>, %b: tensor<2x3x!HLFHE.eint<6>>) -> tensor<3x3x!HLFHE.eint<6>> {
+  %0 = "HLFHELinalg.matmul_int_eint"(%a, %b) : (tensor<3x2xi7>, tensor<2x3x!HLFHE.eint<6>>) -> tensor<3x3x!HLFHE.eint<6>>
+  return %0 : tensor<3x3x!HLFHE.eint<6>>
+}
+)XXX");
+  const uint8_t A[3][2]{
+      {1, 2},
+      {3, 4},
+      {5, 6},
+  };
+  const uint8_t B[2][3]{
+      {1, 2, 3},
+      {2, 3, 4},
+  };
+  const uint8_t expected[3][3]{
+      {5, 8, 11},
+      {11, 18, 25},
+      {17, 28, 39},
+  };
+
+  mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint8_t>>
+      aArg(llvm::ArrayRef<uint8_t>((const uint8_t *)A, 3 * 2), {3, 2});
+  mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint8_t>>
+      bArg(llvm::ArrayRef<uint8_t>((const uint8_t *)B, 2 * 3), {2, 3});
+
+  llvm::Expected<std::vector<uint64_t>> res =
+      lambda.operator()<std::vector<uint64_t>>({&aArg, &bArg});
+
+  ASSERT_EXPECTED_SUCCESS(res);
+
+  ASSERT_EQ(res->size(), (uint64_t)3 * 3);
+
+  for (size_t i = 0; i < 3; i++) {
+    for (size_t j = 0; j < 3; j++) {
+      EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
+          << ", at pos(" << i << "," << j << ")";
+    }
+  }
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // linalg.tensor_collapse_shape ///////////////////////////////////////////////
 ///////////////////////////////////////////////////////////////////////////////
@@ -1376,4 +1431,4 @@ func @main(%a: tensor<2x8x!HLFHE.eint<6>>) -> tensor<2x2x4x!HLFHE.eint<6>> {
     }
   }
 }
 }
 }