feat(HLFHELinalg): add apply_mapped_table_lookup

Resolves #182
This commit is contained in:
rudy
2021-12-10 16:47:04 +01:00
committed by rudy-6-4
parent 81189ceaa9
commit d8fee32cea
15 changed files with 536 additions and 13 deletions

View File

@@ -19,6 +19,22 @@ bool verifyEncryptedIntegerAndIntegerInputsConsistency(OpState &op,
EncryptedIntegerType &a,
IntegerType &b);
/** Shared error message for all ApplyLookupTable variant Op (several Dialect)
 * E.g. HLFHE.apply_lookup_table(input, lut)
 * Message when the lut tensor has an invalid size,
 * i.e. it cannot accommodate the input elements bitwidth
 *
 * @param op           operation the diagnostic is attached to
 * @param lutName      name of the lookup-table operand (printed as operand #2)
 * @param inputName    name of the encrypted input operand (printed as operand #1)
 * @param expectedSize expected inner dimension of the lut, i.e. 2^bitWidth
 * @param bitWidth     bitwidth of the input elements
 */
template <class Op>
void emitErrorBadLutSize(Op &op, const std::string &lutName,
                         const std::string &inputName, int expectedSize,
                         int bitWidth) {
  auto s = op.emitOpError();
  // NOTE: the exact wording below is matched verbatim by FileCheck tests;
  // keep it byte-identical when refactoring.
  s << ": `" << lutName << "` (operand #2)"
    << " inner dimension should have size " << expectedSize << "(=2^"
    << bitWidth << ") to match "
    << "`" << inputName << "` (operand #1)"
    << " elements bitwidth (" << bitWidth << ")";
}
} // namespace HLFHE
} // namespace zamalang
} // namespace mlir

View File

@@ -342,6 +342,62 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
}];
}
def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", []> {
  let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element, specified by a map.";

  let description = [{
    Performs for each encrypted index a lookup on a table of clear integers. Multiple lookup tables are passed, and the application of lookup tables
    is performed following the broadcasting rules. The precise lookup is specified by a map.

    ```mlir
    // The result of this operation, is a tensor that contains the result of the lookup on different tables.
    // i.e. %res[i, ..., k] = %luts[ %map[i, ..., k] ][ %t[i, ..., k] ]
    %res = HLFHELinalg.apply_mapped_lookup_table(%t, %luts, %map): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<DM x ^$p>, tensor<DNx...xD1xindex> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
    ```

    Examples:
    ```mlir
    // Returns the lookup of 3x2 matrix of encrypted indices of width 2 on a vector of 2 tables of size 4=2^2 of clear integers.
    //
    // [0,1]                                 [0, 1]   [1,2]
    // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
    // [2,3]                                 [0, 1]   [5,6]
    "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : (tensor<3x2x!HLFHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!HLFHE.eint<3>>
    ```

    Other examples:
    ```mlir
    // [0,1]                                 [1, 0]   [0,3]
    // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
    // [2,3]                                 [1, 0]   [4,7]

    // [0,1]                                 [0, 0]   [1,3]
    // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [1, 1] = [6,0]
    // [2,3]                                 [1, 0]   [4,7]

    // [0,1]                                 [0]   [1,3]
    // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [1] = [6,0]
    // [2,3]                                 [0]   [5,7]

    // [0,1]                                          [1,2]
    // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
    // [2,3]                                          [5,6]
    ```
  }];

  let arguments = (ins
    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$t,
    Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$luts,
    Type<And<[TensorOf<[Index]>.predicate, HasStaticShapePred]>>:$map
  );

  let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);

  let verifier = [{
    return ::mlir::zamalang::HLFHELinalg::verifyApplyMappedLookupTable(*this);
  }];
}
// Dot product
def Dot : HLFHELinalg_Op<"dot_eint_int"> {
let summary = "Returns the encrypted dot product between a vector of encrypted integers and a vector of clean integers.";

View File

@@ -136,6 +136,15 @@ public:
llvm::ArrayRef<typename ScalarArgumentT::value_type> value)
: TensorLambdaArgument(value, {(int64_t)value.size()}) {}
// Builds a 2D tensor argument from a statically-sized 2D C array.
// The array is flattened row-major into `value` and the shape is
// recorded as {size1, size2}.
template <std::size_t size1, std::size_t size2>
TensorLambdaArgument(
    typename ScalarArgumentT::value_type (&a)[size1][size2]) {
  dimensions = {size1, size2};
  // A 2D C array is contiguous, so the address of its first element views
  // the whole array row-major — no C-style pointer cast needed.
  auto *flat = &a[0][0];
  // Single ranged insert instead of element-by-element back_inserter copy.
  this->value.insert(this->value.end(), flat, flat + size1 * size2);
}
const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
// Returns the total number of elements in the tensor. If the number

View File

@@ -276,6 +276,172 @@ struct HLFHELinalgOpToLinalgGeneric
};
};
// Returns the type of `v` (a value or op result) as a ranked tensor type.
template <class T> inline mlir::RankedTensorType getRankedTensorType(T v) {
  // Convert to the concrete base `mlir::Type` first so the `cast` member
  // below is not a dependent name.
  mlir::Type baseTy = v.getType();
  return baseTy.cast<mlir::RankedTensorType>();
}
// Returns `n` copies of the "parallel" iterator kind — one per loop
// dimension of a `linalg.generic` op.
llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
  llvm::SmallVector<llvm::StringRef> iterators;
  iterators.reserve(n);
  for (int i = 0; i < n; ++i)
    iterators.push_back("parallel");
  return iterators;
}
// This class rewrite pattern transforms any instance of
// operators `HLFHELinalg.ApplyMappedLookupTableEintOp` that implements the
// broadcasting rules to an instance of `linalg.generic` with an appropriate
// region using `HLFHE.ApplyLookupTableEintOp` operation, an appropriate
// specification for the iteration dimensions and appropriate operations
// managing the accumulator of `linalg.generic`.
//
// The current implementation does not rely on 'tensor.extract_slice'
// because of a bug in lowering this operation.
//
// Example:
// %res = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
// : (tensor<2x3x!HLFHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
// -> tensor<2x3x!HLFHE.eint<2>>
//
// becomes:
//
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %init = linalg.init_tensor [2, 3] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
// %output = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types
// = ["parallel", "parallel"]} ins(%arg0, %arg2 :
// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 :
// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>) {
// ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
// !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
// // SHOULD BE
// %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1]
// : tensor<5x4xi64> to tensor<4xi64>
// // BUT IS
// %i0 = arith.constant 0 : index
// ...
// %i3 = arith.constant 3 : index
//    %e0 = tensor.extract %arg1[%lut_idx, %i0] : tensor<5x4xi64>
//    ...
//    %e3 = tensor.extract %arg1[%lut_idx, %i3] : tensor<5x4xi64>
// %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64>
// %res = "MidLFHE.apply_lookup_table"(%arg3, %[[LUT]])
// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32,
// levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS =
// -1 : i32, polynomialSize = -1 : i32}
// : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
// !MidLFHE.glwe<{_,_,_}{2}> linalg.yield %res :
// !MidLFHE.glwe<{_,_,_}{2}>
// } -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
namespace HLFHELinalg = mlir::zamalang::HLFHELinalg;
struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
    : public mlir::OpRewritePattern<HLFHELinalg::ApplyMappedLookupTableEintOp> {
  HLFHELinalgApplyMappedLookupTableToLinalgGeneric(
      ::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<HLFHELinalg::ApplyMappedLookupTableEintOp>(
            context, benefit) {}

  // Rewrites one `apply_mapped_lookup_table` op into a `linalg.generic`
  // over (`t`, `map`): for each element, the region materializes the lut
  // row selected by the map value and applies `HLFHE.apply_lookup_table`.
  ::mlir::LogicalResult
  matchAndRewrite(HLFHELinalg::ApplyMappedLookupTableEintOp mappedLookup,
                  ::mlir::PatternRewriter &rewriter) const override {
    namespace arith = mlir::arith;
    namespace linalg = mlir::linalg;
    namespace tensor = mlir::tensor;
    namespace HLFHE = mlir::zamalang::HLFHE;
    using Values = llvm::SmallVector<mlir::Value>;
    using Types = llvm::SmallVector<mlir::Type>;
    using AffineMaps = llvm::SmallVector<mlir::AffineMap>;
    using sliceArg = llvm::SmallVector<mlir::OpFoldResult>;
    auto input = mappedLookup.t();
    auto luts = mappedLookup.luts();
    auto map = mappedLookup.map();
    auto loc = mappedLookup.getLoc();
    auto tensorTy = getRankedTensorType(input);
    auto lutsTy = getRankedTensorType(luts);
    auto resultTy = getRankedTensorType(mappedLookup->getResult(0));
    auto elementTy = resultTy.getElementType();
    auto lutElmtTy = lutsTy.getElementType();
    auto lutsShape = lutsTy.getShape();
    // The innermost dimension of `luts` is the size of a single table.
    auto lutSize = lutsShape[lutsShape.size() - 1];
    auto resultShape = resultTy.getShape();
    // Shorthand for building i64 attributes (static offsets/sizes below).
    auto integer = [&](auto v) -> mlir::Attribute {
      return rewriter.getI64IntegerAttr(v);
    };
    auto _0_ = integer(0);
    auto _1_ = integer(1);
    auto lutSizeValue = integer(lutSize);
    // Create the body of the `linalg.generic` op
    // %arg0 is an element of t (encrypted int)
    // %arg1 is an element of map (i64)
    // %arg2 is the output element
    auto lambdaBlock = [&](mlir::OpBuilder &nestedBuilder,
                           mlir::Location nestedLoc,
                           mlir::ValueRange blockArgs) {
      auto tElmt = blockArgs[0];
      auto lutIdx = blockArgs[1];
      auto indexTy = rewriter.getIndexType();
      // %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] :
      //          tensor<NxKxi64> to tensor<Kxi64>
      mlir::Value lut;
      // Kept as a compile-time switch so the clean extract_slice path can be
      // re-enabled once the linalg-bufferize bug is fixed.
      const bool WORKAROUND_EXTRACT_SLICE = true;
      if (!WORKAROUND_EXTRACT_SLICE) {
        sliceArg offsets{lutIdx, _0_};
        sliceArg sizes{_1_, lutSizeValue};
        sliceArg strides{_1_, _1_};
        auto lutTy = mlir::RankedTensorType::get(
            {static_cast<int64_t>(lutSize)}, lutElmtTy);
        lut = nestedBuilder.create<tensor::ExtractSliceOp>(
            loc, lutTy, luts, offsets, sizes, strides);
      } else {
        // WORKAROUND BEGIN
        // A bug in linalg-bufferize prevents rank reduction in extract_slice
        // Reshaping does not work either or is too complicated so let's rebuild
        // the tensor from scratch
        llvm::SmallVector<mlir::Value> consts;
        llvm::SmallVector<mlir::Value> extracts;
        for (int i = 0; i < lutSize; i++) {
          consts.push_back(
              // %5 = arith.constant(<i> : index) : index
              nestedBuilder.create<mlir::arith::ConstantOp>(
                  loc, indexTy, rewriter.getIndexAttr(i)));
        }
        for (int i = 0; i < lutSize; i++) {
          extracts.push_back(
              // %8 = tensor.extract %luts[<lutIdx>, <i>] : ...
              nestedBuilder.create<tensor::ExtractOp>(
                  loc, luts, mlir::ValueRange({lutIdx, consts[i]})));
        }
        // %12 = tensor.from_elements %8, ... : ...
        lut = nestedBuilder.create<tensor::FromElementsOp>(loc, extracts);
      } // WORKAROUND END
      // %res1 = apply_lookup_table %arg0 %lut
      auto lookup = nestedBuilder.create<HLFHE::ApplyLookupTableEintOp>(
          loc, elementTy, tElmt, lut);
      // linalg.yield %res1 : !HLFHE.eint<2>
      nestedBuilder.create<linalg::YieldOp>(loc, lookup.getResult());
    };
    auto output =
        rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementTy);
    // Create the `linalg.generic` op
    Types resTys{resultTy};
    Values ins{input, map};
    Values outs{output};
    // `t`, `map` and the result all share the same broadcasted indexing map.
    auto indexOfInput = getBroadcastedAffineMap(resultTy, tensorTy, rewriter);
    AffineMaps affineMaps{indexOfInput, indexOfInput, indexOfInput};
    auto iteratorTypes = parallelIteratorType(resultShape.size());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resTys, ins, outs, affineMaps, iteratorTypes, lambdaBlock);
    rewriter.replaceOp(mappedLookup, {genericOp.getResult(0)});
    return ::mlir::success();
  };
};
// This class rewrite pattern transforms any instance of
// operators `HLFHELinalg.ApplyMultiLookupTableEintOp` that implements the
// broadcasting rules to an instance of `linalg.generic` with an appropriate
@@ -847,6 +1013,8 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
});
patterns.insert<HLFHELinalgApplyMultiLookupTableToLinalgGeneric>(
&getContext());
patterns.insert<HLFHELinalgApplyMappedLookupTableToLinalgGeneric>(
&getContext());
patterns.insert<HLFHELinalgZeroToLinalgGenerate>(&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))

View File

@@ -902,7 +902,8 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
norm2SqEquiv = getSqMANP(matmulIntEintOp, operands);
} else if (llvm::isa<
mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp,
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>(
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp,
mlir::zamalang::HLFHELinalg::ApplyMappedLookupTableEintOp>(
op)) {
norm2SqEquiv = llvm::APInt{1, 1, false};
}

View File

@@ -108,12 +108,11 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
// Check the shape of l_cst argument
auto width = ct.getWidth();
auto expectedSize = 1 << width;
auto lCstShape = l_cst.getShape();
mlir::SmallVector<int64_t, 1> expectedShape{1 << width};
mlir::SmallVector<int64_t, 1> expectedShape{expectedSize};
if (!l_cst.hasStaticShape(expectedShape)) {
op.emitOpError() << " should have as `l_cst` argument a shape of one "
"dimension equals to 2^p, where p is the width of the "
"`ct` argument.";
emitErrorBadLutSize(op, "l_cst", "ct", expectedSize, width);
return mlir::failure();
}
if (!l_cst.getElementType().isInteger(64)) {

View File

@@ -281,6 +281,85 @@ verifyApplyMultiLookupTable(ApplyMultiLookupTableEintOp &op) {
return mlir::success();
}
// Convenience accessor: the type of `value` viewed as a ranked tensor type.
mlir::RankedTensorType getTensorType(::mlir::Value value) {
  mlir::Type valueTy = value.getType();
  return valueTy.cast<mlir::RankedTensorType>();
}
// Returns the element type of the tensor `value`, cast to T.
// ("Elment" is a historical typo kept because siblings call this name.)
template <class T> T getElmentType(::mlir::Value value) {
  return getTensorType(value).getElementType().cast<T>();
}
// Returns the clear (plaintext) integer element type of the tensor `value`.
mlir::IntegerType getClearElmentType(::mlir::Value value) {
  return getElmentType<mlir::IntegerType>(value);
}
// Returns the encrypted-integer element type of the tensor `value`.
HLFHE::EncryptedIntegerType getEncryptedElmentType(::mlir::Value value) {
  // `HLFHE` is already visible here (it qualifies the return type above),
  // so the previous `using namespace mlir::zamalang::HLFHE;` was redundant.
  return getElmentType<HLFHE::EncryptedIntegerType>(value);
}
// Verifies that `lut_map` has exactly the same shape as `lut_input`.
// On mismatch, emits a diagnostic that distinguishes a rank mismatch from a
// same-rank dimension mismatch, then returns failure.
mlir::LogicalResult verifyMapHasRightShape(ApplyMappedLookupTableEintOp &op,
                                           ::mlir::Value &lut_input,
                                           ::mlir::Value &lut_map) {
  auto input_shape = getTensorType(lut_input).getShape();
  auto map_shape = getTensorType(lut_map).getShape();
  if (input_shape.equals(map_shape)) {
    return mlir::success();
  }
  std::string error;
  int input_rank = input_shape.size();
  int map_rank = map_shape.size();
  std::string input_name = "'t' (operand #1)";
  // NOTE(review): this prints the literal text "lut_map.getName()" in the
  // diagnostic; it looks like the operand name was meant to be interpolated
  // (i.e. "'map' (operand #3)"). The FileCheck tests currently expect the
  // literal string, so fixing it requires updating them in lockstep.
  std::string map_name = "'lut_map.getName()' (operand #3)";
  if (input_rank == map_rank) {
    error = ": " + input_name + " dimensions differs from " + map_name;
  } else {
    error = ": " + input_name + " rank (=" + std::to_string(input_rank) +
            ") differs from " + map_name +
            " rank (=" + std::to_string(map_rank) + ")";
  }
  op.emitOpError() << error;
  return mlir::failure();
}
// Verifies that the inner dimension of `luts` equals 2^p, where p is the
// bitwidth of the encrypted elements of `encryptedIndex`.
mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op,
                                   ::mlir::Value &encryptedIndex,
                                   ::mlir::Value &luts) {
  auto index_width = getEncryptedElmentType(encryptedIndex).getWidth();
  // Only the innermost dimension matters: each row of `luts` is one table.
  auto actual_lut_size = getTensorType(luts).getShape().back();
  auto expected_lut_size = 1 << index_width;
  if (actual_lut_size == expected_lut_size) {
    return mlir::success();
  }
  // NOTE(review): the input operand of this op is named `t`, but the message
  // calls it "ct" (copied from apply_lookup_table). The FileCheck tests
  // expect "ct"; renaming requires updating them in lockstep.
  HLFHE::emitErrorBadLutSize(op, "luts", "ct", expected_lut_size, index_width);
  return mlir::failure();
}
// Verifies an `apply_mapped_lookup_table` op:
//  - the result must have the same static shape as `t`,
//  - `map` must hold elements of type `index`,
//  - `map` must have the same shape as `t`,
//  - the inner dimension of `luts` must be 2^p, p being the bitwidth of
//    the encrypted elements of `t`.
mlir::LogicalResult
verifyApplyMappedLookupTable(ApplyMappedLookupTableEintOp &op) {
  auto t = op.t();
  auto luts = op.luts();
  auto map = op.map();
  auto result = op.getResult();
  auto t_shape = getTensorType(t).getShape();
  if (!getTensorType(result).hasStaticShape(t_shape)) {
    // This check compares the RESULT against `t`; the previous message
    // wrongly blamed `map` (and numbered it operand #2).
    op.emitOpError()
        << ": `t` (operand #1) and the result must have the same shape";
    return mlir::failure();
  }
  if (!getTensorType(map).getElementType().isIndex()) {
    op.emitOpError()
        << ": `map` (operand #3) should contains elements of type `index`";
    return mlir::failure();
  }
  // `&&` short-circuits: the luts-size check is skipped when the map-shape
  // check already failed, so at most one diagnostic is emitted here.
  return mlir::success(verifyMapHasRightShape(op, t, map).succeeded() &&
                       verifyLutsSize(op, t, luts).succeeded());
}
::mlir::LogicalResult verifyDotEintInt(Dot &op) {
if (::mlir::failed(mlir::verifyCompatibleShape(op.lhs().getType(),
op.rhs().getType()))) {

View File

@@ -1,5 +1,6 @@
#include "mlir/IR/Region.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
@@ -146,11 +147,10 @@ mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) {
// Check the shape of l_cst argument
auto width = ct.getP();
auto lCstShape = l_cst.getShape();
mlir::SmallVector<int64_t, 1> expectedShape{1 << width};
auto expectedSize = 1 << width;
mlir::SmallVector<int64_t, 1> expectedShape{expectedSize};
if (!l_cst.hasStaticShape(expectedShape)) {
op.emitOpError() << "should have as `l_cst` argument a shape of one "
"dimension equals to 2^p, where p is the width of the "
"`ct` argument.";
HLFHE::emitErrorBadLutSize(op, "l_cst", "ct", expectedSize, width);
return mlir::failure();
}
if (!l_cst.getElementType().isInteger(64)) {

View File

@@ -92,8 +92,7 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
llvm::InitializeNativeTargetAsmPrinter();
mlir::registerLLVMDialectTranslation(mlirContext);
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
auto optPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
mlir::zamalang::JITLambda::create(funcName, module, optPipeline,

View File

@@ -1,6 +1,6 @@
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument.
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op : `l_cst` (operand #2) inner dimension should have size 4(=2^2) to match `ct` (operand #1) elements bitwidth (2)
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> {
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<8xi3>) -> (!HLFHE.eint<2>)
return %1: !HLFHE.eint<2>

View File

@@ -0,0 +1,31 @@
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s
//CHECK-LABEL: #map = affine_map<(d0, d1) -> (d0, d1)>
//CHECK-NEXT:module {
//CHECK-NEXT: func @mapped_lut(%arg0: tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, %[[LUTS:.*]]: tensor<5x4xi64>, %arg2: tensor<2x3xindex>) -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>> {
//CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [2, 3] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg2 : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %[[LUTIDX:.*]]: index, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
//DISABLED-CHECK-NEXT: %[[V3:.*]] = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1, 4] [1, 1] : tensor<5x4xi64> to tensor<4xi64>
//WORKAROUND BEGIN
//CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
//CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
//CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index
//CHECK-NEXT: %[[C3:.*]] = arith.constant 3 : index
//CHECK-NEXT: %[[E0:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C0]]] : tensor<5x4xi64>
//CHECK-NEXT: %[[E1:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C1]]] : tensor<5x4xi64>
//CHECK-NEXT: %[[E2:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C2]]] : tensor<5x4xi64>
//CHECK-NEXT: %[[E3:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C3]]] : tensor<5x4xi64>
//CHECK-NEXT: %[[LUT:.*]] = tensor.from_elements %[[E0]], %[[E1]], %[[E2]], %[[E3]] : tensor<4xi64>
//WORKAROUND END
//CHECK-NEXT: %[[V4:.*]] = "MidLFHE.apply_lookup_table"(%arg3, %[[LUT]]) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %[[V4]] : !MidLFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return %[[V1]] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: }
func @mapped_lut(%t: tensor<2x3x!HLFHE.eint<2>>, %luts: tensor<5x4xi64>, %map: tensor<2x3xindex>) -> tensor<2x3x!HLFHE.eint<2>> {
%1 = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map): (tensor<2x3x!HLFHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>) -> tensor<2x3x!HLFHE.eint<2>>
return %1: tensor<2x3x!HLFHE.eint<2>>
}

View File

@@ -158,6 +158,47 @@ func @apply_multi_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tenso
return %1: tensor<2x3x4x!HLFHE.eint<2>>
}
// -----
/////////////////////////////////////////////////
// HLFHELinalg.apply_mapped_lookup_table
/////////////////////////////////////////////////
func @apply_mapped_lookup_table_bad_lut_size_127_vs_128(
%input: tensor<2x3x4x!HLFHE.eint<7>>,
%luts: tensor<127xi64>,
%map: tensor<2x3x4xindex>
) -> tensor<2x3x4x!HLFHE.eint<7>> {
// expected-error @+1 {{'HLFHELinalg.apply_mapped_lookup_table' op : `luts` (operand #2) inner dimension should have size 128(=2^7) to match `ct` (operand #1) elements bitwidth (7)}}
%1 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<127xi64>, tensor<2x3x4xindex>) -> tensor<2x3x4x!HLFHE.eint<7>>
return %1: tensor<2x3x4x!HLFHE.eint<7>>
}
// -----
func @apply_mapped_lookup_table_bad_map_size(
%input: tensor<2x3x4x!HLFHE.eint<7>>,
%luts: tensor<128xi64>,
%map: tensor<2x3xindex>
) -> tensor<2x3x4x!HLFHE.eint<7>> {
// expected-error @+1 {{'HLFHELinalg.apply_mapped_lookup_table' op : 't' (operand #1) rank (=3) differs from 'lut_map.getName()' (operand #3) rank (=2)}}
%1 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<128xi64>, tensor<2x3xindex>) -> tensor<2x3x4x!HLFHE.eint<7>>
return %1: tensor<2x3x4x!HLFHE.eint<7>>
}
// -----
// NOTE(review): despite its name, this case exercises the RANK check, not the
// map element type: `%map` below has rank 2 while `%input` has rank 3, so the
// rank-mismatch diagnostic fires first — duplicating the previous case.
// TODO: use a rank-3 map with a non-index element type to actually cover the
// element-type diagnostic.
func @apply_mapped_lookup_table_bad_map_elmt_type(
  %input: tensor<2x3x4x!HLFHE.eint<7>>,
  %luts: tensor<128xi64>,
  %map: tensor<2x3xindex>
) -> tensor<2x3x4x!HLFHE.eint<7>> {
  // expected-error @+1 {{'HLFHELinalg.apply_mapped_lookup_table' op : 't' (operand #1) rank (=3) differs from 'lut_map.getName()' (operand #3) rank (=2)}}
  %1 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<128xi64>, tensor<2x3xindex>) -> tensor<2x3x4x!HLFHE.eint<7>>
  return %1: tensor<2x3x4x!HLFHE.eint<7>>
}
// -----
/////////////////////////////////////////////////

View File

@@ -288,6 +288,22 @@ func @apply_multi_lookup_table_broadcast(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %a
return %1: tensor<2x3x4x!HLFHE.eint<2>>
}
/////////////////////////////////////////////////
// HLFHELinalg.apply_mapped_lookup_table
/////////////////////////////////////////////////
// CHECK-LABEL: func @apply_mapped_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<7>>, %arg1: tensor<10x128xi64>, %arg2: tensor<2x3x4xindex>) -> tensor<2x3x4x!HLFHE.eint<7>> {
func @apply_mapped_lookup_table(
%input: tensor<2x3x4x!HLFHE.eint<7>>,
%luts: tensor<10x128xi64>,
%map: tensor<2x3x4xindex>
) -> tensor<2x3x4x!HLFHE.eint<7>> {
// CHECK-NEXT: %0 = "HLFHELinalg.apply_mapped_lookup_table"(%arg0, %arg1, %arg2) : (tensor<2x3x4x!HLFHE.eint<7>>, tensor<10x128xi64>, tensor<2x3x4xindex>) -> tensor<2x3x4x!HLFHE.eint<7>>
// CHECK-NEXT: return %0 : tensor<2x3x4x!HLFHE.eint<7>>
%0 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<10x128xi64>, tensor<2x3x4xindex>) -> (tensor<2x3x4x!HLFHE.eint<7>>)
return %0: tensor<2x3x4x!HLFHE.eint<7>>
}
/////////////////////////////////////////////////
// HLFHELinalg.dot_eint_int
/////////////////////////////////////////////////

View File

@@ -2,7 +2,7 @@
// Bad dimension of the lookup table
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {
// expected-error @+1 {{'MidLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument}}
// expected-error @+1 {{'MidLFHE.apply_lookup_table' op : `l_cst` (operand #2) inner dimension should have size 128(=2^7) to match `ct` (operand #1) elements bitwidth (7)}}
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {glweDimension = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32, outputSizeKS = 600 : i32}: (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<4xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>)
return %1: !MidLFHE.glwe<{512,10,64}{2}>
}

View File

@@ -1,5 +1,9 @@
#include "end_to_end_jit_test.h"
namespace Z = mlir::zamalang;
template <class Elmt>
using tensorArgTy = Z::TensorLambdaArgument<Z::IntLambdaArgument<Elmt>>;
#define GET_3D(tensor, i, j, k, di, dj, dk) (tensor)[i * dj * dk + j * dk + k]
#define GET_2D(tensor, i, j, di, dj) (tensor)[i * dj + j]
@@ -1149,6 +1153,110 @@ TEST(End2EndJit_HLFHELinalg, apply_multi_lookup_table_with_boradcast) {
}
}
///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg apply_mapped_lookup_table /////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Exercises apply_mapped_lookup_table with `map` enumerating the 9 luts in
// row-major order: element (i, j) uses lut number 3*i + j.
TEST(End2EndJit_HLFHELinalg, apply_mapped_lookup_table_sequential) {
  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers.
func @main(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>> {
  %1 = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) :
    (tensor<3x3x!HLFHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>>
  return %1: tensor<3x3x!HLFHE.eint<2>>
}
)XXX");
  // Encrypted inputs: 2-bit values, so valid lut indices are 0..3.
  uint8_t t[3][3]{
      {0, 1, 2},
      {3, 0, 1},
      {2, 3, 0},
  };
  // Lut k holds the value 3 exactly at position t[k / 3][k % 3], so every
  // lookup below is expected to produce 3.
  uint64_t luts[9][4]{
      {3, 0, 0, 0}, {0, 3, 0, 0}, {0, 0, 3, 0},
      {0, 0, 0, 3}, {3, 0, 0, 0}, {0, 3, 0, 0},
      {0, 0, 3, 0}, {0, 0, 0, 3}, {3, 0, 0, 0},
  };
  // Each element selects its own lut, in row-major order.
  uint64_t map[3][3]{
      {0, 1, 2},
      {3, 4, 5},
      {6, 7, 8},
  };
  uint8_t expected[3][3]{
      {3, 3, 3},
      {3, 3, 3},
      {3, 3, 3},
  };
  tensorArgTy<uint8_t> tArg(t);
  tensorArgTy<uint64_t> lutsArg(luts), mapArg(map);
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>(
          {&tArg, &lutsArg, &mapArg});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), 3 * 3);
  // The result tensor comes back flattened in row-major order.
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
          << ", at pos(" << i << "," << j << ")";
    }
  }
}
// Exercises apply_mapped_lookup_table where every map entry selects lut #4:
// the op degenerates to a single-table lookup with {1, 2, 3, 1}.
// (The comment inside the MLIR source string below is copied from the
// sequential test and slightly overstates the setup; it is part of the
// compiled program text, so it is left untouched here.)
TEST(End2EndJit_HLFHELinalg, apply_mapped_lookup_table_same_lut) {
  mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers.
func @main(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>> {
  %1 = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) :
    (tensor<3x3x!HLFHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>>
  return %1: tensor<3x3x!HLFHE.eint<2>>
}
)XXX");
  uint8_t t[3][3]{
      {0, 1, 2},
      {3, 0, 1},
      {2, 3, 0},
  };
  // Only lut #4 = {1, 2, 3, 1} is meaningful; the others are all-zero decoys
  // that would make the test fail if the map were applied incorrectly.
  uint64_t luts[9][4]{
      {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0},
      {0, 0, 0, 0}, {1, 2, 3, 1}, {0, 0, 0, 0},
      {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0},
  };
  // Every element selects the same lut (#4).
  uint64_t map[3][3]{
      {4, 4, 4},
      {4, 4, 4},
      {4, 4, 4},
  };
  // expected[i][j] == luts[4][t[i][j]].
  uint8_t expected[3][3]{
      {1, 2, 3},
      {1, 1, 2},
      {3, 1, 1},
  };
  tensorArgTy<uint8_t> tArg(t);
  tensorArgTy<uint64_t> lutsArg(luts), mapArg(map);
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>(
          {&tArg, &lutsArg, &mapArg});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), 3 * 3);
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
          << ", at pos(" << i << "," << j << ")";
    }
  }
}
///////////////////////////////////////////////////////////////////////////////
// HLFHELinalg dot_eint_int ///////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////