mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
@@ -19,6 +19,22 @@ bool verifyEncryptedIntegerAndIntegerInputsConsistency(OpState &op,
|
||||
EncryptedIntegerType &a,
|
||||
IntegerType &b);
|
||||
|
||||
/** Shared error message for all ApplyLookupTable variant Op (several Dialect)
|
||||
* E.g. HLFHE.apply_lookup_table(input, lut)
|
||||
* Message when the lut tensor has an invalid size,
|
||||
* i.e. it cannot accomodate the input elements bitwidth
|
||||
*/
|
||||
template <class Op>
|
||||
void emitErrorBadLutSize(Op &op, std::string lutName, std::string inputName,
|
||||
int expectedSize, int bitWidth) {
|
||||
auto s = op.emitOpError();
|
||||
s << ": `" << lutName << "` (operand #2)"
|
||||
<< " inner dimension should have size " << expectedSize << "(=2^"
|
||||
<< bitWidth << ") to match "
|
||||
<< "`" << inputName << "` (operand #1)"
|
||||
<< " elements bitwidth (" << bitWidth << ")";
|
||||
}
|
||||
|
||||
} // namespace HLFHE
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -342,6 +342,62 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
|
||||
}];
|
||||
}
|
||||
|
||||
def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", []> {
|
||||
let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element, specified by a map.";
|
||||
|
||||
let description = [{
|
||||
Performs for each encrypted indice a lookup on a table of clear integers. Multiple lookup tables are passed, and the application of lookup tables
|
||||
is performed following the broadcasting rules. The precise lookup is specified by a map.
|
||||
|
||||
```mlir
|
||||
// The result of this operation, is a tensor that contains the result of the lookup on different tables.
|
||||
// i.e. %res[i, ..., k] = %luts[ %map[i, ..., k] ][ %t[i, ..., k] ]
|
||||
%res = HLFHELinalg.apply_mapped_lookup_table(%t, %luts, %map): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<DM x ^$p>, tensor<DNx...xD1xindex> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
|
||||
```
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
|
||||
// Returns the lookup of 3x2 matrix of encrypted indices of width 2 on a vector of 2 tables of size 4=2^2 of clear integers.
|
||||
//
|
||||
// [0,1] [0, 1] = [1,2]
|
||||
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
|
||||
// [2,3] [0, 1] = [5,6]
|
||||
"HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : (tensor<3x2x!HLFHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!HLFHE.eint<3>>
|
||||
```
|
||||
|
||||
Others examples:
|
||||
// [0,1] [1, 0] = [3,2]
|
||||
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
|
||||
// [2,3] [1, 0] = [4,7]
|
||||
|
||||
// [0,1] [0, 0] = [1,3]
|
||||
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [1, 1] = [6,0]
|
||||
// [2,3] [1, 0] = [4,7]
|
||||
|
||||
// [0,1] [0] = [1,3]
|
||||
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [1] = [6,0]
|
||||
// [2,3] [0] = [5,7]
|
||||
|
||||
// [0,1] = [1,2]
|
||||
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
|
||||
// [2,3] = [5,6]
|
||||
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$t,
|
||||
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$luts,
|
||||
Type<And<[TensorOf<[Index]>.predicate, HasStaticShapePred]>>:$map
|
||||
);
|
||||
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::zamalang::HLFHELinalg::verifyApplyMappedLookupTable(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
// Dot product
|
||||
def Dot : HLFHELinalg_Op<"dot_eint_int"> {
|
||||
let summary = "Returns the encrypted dot product between a vector of encrypted integers and a vector of clean integers.";
|
||||
|
||||
@@ -136,6 +136,15 @@ public:
|
||||
llvm::ArrayRef<typename ScalarArgumentT::value_type> value)
|
||||
: TensorLambdaArgument(value, {(int64_t)value.size()}) {}
|
||||
|
||||
template <std::size_t size1, std::size_t size2>
|
||||
TensorLambdaArgument(
|
||||
typename ScalarArgumentT::value_type (&a)[size1][size2]) {
|
||||
dimensions = {size1, size2};
|
||||
auto value = llvm::MutableArrayRef<typename ScalarArgumentT::value_type>(
|
||||
(typename ScalarArgumentT::value_type *)a, size1 * size2);
|
||||
std::copy(value.begin(), value.end(), std::back_inserter(this->value));
|
||||
}
|
||||
|
||||
const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
|
||||
|
||||
// Returns the total number of elements in the tensor. If the number
|
||||
|
||||
@@ -276,6 +276,172 @@ struct HLFHELinalgOpToLinalgGeneric
|
||||
};
|
||||
};
|
||||
|
||||
template <class T> inline mlir::RankedTensorType getRankedTensorType(T v) {
|
||||
return ((mlir::Type)v.getType()).cast<mlir::RankedTensorType>();
|
||||
}
|
||||
|
||||
llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
|
||||
return llvm::SmallVector<llvm::StringRef>(n, "parallel");
|
||||
}
|
||||
|
||||
// This class rewrite pattern transforms any instance of
|
||||
// operators `HLFHELinalg.ApplyMappedLookupTableEintOp` that implements the
|
||||
// broadasting rules to an instance of `linalg.generic` with an appropriate
|
||||
// region using `HLFHE.ApplyLookupTableEintOp` operation, an appropriate
|
||||
// specification for the iteration dimensions and appropriate operations
|
||||
// managing the accumulator of `linalg.generic`.
|
||||
//
|
||||
// The current implementation does not rely on 'tensor.extract_slice'
|
||||
// because of a bug in lowering this operation.
|
||||
//
|
||||
// Example:
|
||||
// %res = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
|
||||
// : (tensor<2x3x!HLFHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
|
||||
// -> tensor<2x3x!HLFHE.eint<2>>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// %init = linalg.init_tensor [2, 3] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
|
||||
// %output = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types
|
||||
// = ["parallel", "parallel"]} ins(%arg0, %arg2 :
|
||||
// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 :
|
||||
// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>) {
|
||||
// ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
|
||||
// !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
|
||||
// // SHOULD BE
|
||||
// %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1]
|
||||
// : tensor<5x4xi64> to tensor<4xi64>
|
||||
// // BUT IS
|
||||
// %i0 = arith.constant 0 : index
|
||||
// ...
|
||||
// %i3 = arith.constant 3 : index
|
||||
// %e0 = tensor.extract %arg5[%lut_idx, %i0] : tensor<5x4xi64>
|
||||
// ...
|
||||
// %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64>
|
||||
// %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64>
|
||||
// %res = "MidLFHE.apply_lookup_table"(%arg3, %[[LUT]])
|
||||
// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32,
|
||||
// levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS =
|
||||
// -1 : i32, polynomialSize = -1 : i32}
|
||||
// : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
|
||||
// !MidLFHE.glwe<{_,_,_}{2}> linalg.yield %res :
|
||||
// !MidLFHE.glwe<{_,_,_}{2}>
|
||||
// } -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
|
||||
|
||||
namespace HLFHELinalg = mlir::zamalang::HLFHELinalg;
|
||||
|
||||
struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric
|
||||
: public mlir::OpRewritePattern<HLFHELinalg::ApplyMappedLookupTableEintOp> {
|
||||
HLFHELinalgApplyMappedLookupTableToLinalgGeneric(
|
||||
::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<HLFHELinalg::ApplyMappedLookupTableEintOp>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(HLFHELinalg::ApplyMappedLookupTableEintOp mappedLookup,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
namespace arith = mlir::arith;
|
||||
namespace linalg = mlir::linalg;
|
||||
namespace tensor = mlir::tensor;
|
||||
namespace HLFHE = mlir::zamalang::HLFHE;
|
||||
using Values = llvm::SmallVector<mlir::Value>;
|
||||
using Types = llvm::SmallVector<mlir::Type>;
|
||||
using AffineMaps = llvm::SmallVector<mlir::AffineMap>;
|
||||
using sliceArg = llvm::SmallVector<mlir::OpFoldResult>;
|
||||
|
||||
auto input = mappedLookup.t();
|
||||
auto luts = mappedLookup.luts();
|
||||
auto map = mappedLookup.map();
|
||||
|
||||
auto loc = mappedLookup.getLoc();
|
||||
auto tensorTy = getRankedTensorType(input);
|
||||
auto lutsTy = getRankedTensorType(luts);
|
||||
auto resultTy = getRankedTensorType(mappedLookup->getResult(0));
|
||||
auto elementTy = resultTy.getElementType();
|
||||
auto lutElmtTy = lutsTy.getElementType();
|
||||
auto lutsShape = lutsTy.getShape();
|
||||
auto lutSize = lutsShape[lutsShape.size() - 1];
|
||||
auto resultShape = resultTy.getShape();
|
||||
|
||||
auto integer = [&](auto v) -> mlir::Attribute {
|
||||
return rewriter.getI64IntegerAttr(v);
|
||||
};
|
||||
|
||||
auto _0_ = integer(0);
|
||||
auto _1_ = integer(1);
|
||||
auto lutSizeValue = integer(lutSize);
|
||||
|
||||
// Create the body of the `linalg.generic` op
|
||||
// %arg0 is an element of t (encrypted int)
|
||||
// %arg1 is an element of map (i64)
|
||||
// %arg2 is the output element
|
||||
auto lambdaBlock = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
auto tElmt = blockArgs[0];
|
||||
auto lutIdx = blockArgs[1];
|
||||
auto indexTy = rewriter.getIndexType();
|
||||
|
||||
// %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] :
|
||||
// tensor<NxKxi64> to tensor<Kxi64>
|
||||
mlir::Value lut;
|
||||
const bool WORKAROUND_EXTRACT_SLICE = true;
|
||||
if (!WORKAROUND_EXTRACT_SLICE) {
|
||||
sliceArg offsets{lutIdx, _0_};
|
||||
sliceArg sizes{_1_, lutSizeValue};
|
||||
sliceArg strides{_1_, _1_};
|
||||
auto lutTy = mlir::RankedTensorType::get(
|
||||
{static_cast<int64_t>(lutSize)}, lutElmtTy);
|
||||
lut = nestedBuilder.create<tensor::ExtractSliceOp>(
|
||||
loc, lutTy, luts, offsets, sizes, strides);
|
||||
} else {
|
||||
// WORKAROUND BEGIN
|
||||
// A bug in linalg-bufferize prevents rank reduction in extract_slice
|
||||
// Reshaping does not work either or is too complicated so let's rebuild
|
||||
// the tensor from scratch
|
||||
llvm::SmallVector<mlir::Value> consts;
|
||||
llvm::SmallVector<mlir::Value> extracts;
|
||||
for (int i = 0; i < lutSize; i++) {
|
||||
consts.push_back(
|
||||
// %5 = arith.constant(<i> : index) : index
|
||||
nestedBuilder.create<mlir::arith::ConstantOp>(
|
||||
loc, indexTy, rewriter.getIndexAttr(i)));
|
||||
}
|
||||
for (int i = 0; i < lutSize; i++) {
|
||||
extracts.push_back(
|
||||
// %8 = tensor.extract %luts[<lutIdx>, <i>] : ...
|
||||
nestedBuilder.create<tensor::ExtractOp>(
|
||||
loc, luts, mlir::ValueRange({lutIdx, consts[i]})));
|
||||
}
|
||||
// %12 = tensor.from_elements %8, ... : ...
|
||||
lut = nestedBuilder.create<tensor::FromElementsOp>(loc, extracts);
|
||||
} // WORKAROUND END
|
||||
// %res1 = apply_lookup_table %arg0 %lut
|
||||
auto lookup = nestedBuilder.create<HLFHE::ApplyLookupTableEintOp>(
|
||||
loc, elementTy, tElmt, lut);
|
||||
// linalg.yield %res1 : !HLFHE.eint<2>
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, lookup.getResult());
|
||||
};
|
||||
|
||||
auto output =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, resultShape, elementTy);
|
||||
|
||||
// Create the `linalg.g eneric` op
|
||||
Types resTys{resultTy};
|
||||
Values ins{input, map};
|
||||
Values outs{output};
|
||||
auto indexOfInput = getBroadcastedAffineMap(resultTy, tensorTy, rewriter);
|
||||
AffineMaps affineMaps{indexOfInput, indexOfInput, indexOfInput};
|
||||
auto iteratorTypes = parallelIteratorType(resultShape.size());
|
||||
auto genericOp = rewriter.create<linalg::GenericOp>(
|
||||
loc, resTys, ins, outs, affineMaps, iteratorTypes, lambdaBlock);
|
||||
rewriter.replaceOp(mappedLookup, {genericOp.getResult(0)});
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// This class rewrite pattern transforms any instance of
|
||||
// operators `HLFHELinalg.ApplyMultiLookupTableEintOp` that implements the
|
||||
// broadasting rules to an instance of `linalg.generic` with an appropriate
|
||||
@@ -847,6 +1013,8 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
|
||||
});
|
||||
patterns.insert<HLFHELinalgApplyMultiLookupTableToLinalgGeneric>(
|
||||
&getContext());
|
||||
patterns.insert<HLFHELinalgApplyMappedLookupTableToLinalgGeneric>(
|
||||
&getContext());
|
||||
patterns.insert<HLFHELinalgZeroToLinalgGenerate>(&getContext());
|
||||
|
||||
if (mlir::applyPartialConversion(function, target, std::move(patterns))
|
||||
|
||||
@@ -902,7 +902,8 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
norm2SqEquiv = getSqMANP(matmulIntEintOp, operands);
|
||||
} else if (llvm::isa<
|
||||
mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp,
|
||||
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>(
|
||||
mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp,
|
||||
mlir::zamalang::HLFHELinalg::ApplyMappedLookupTableEintOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = llvm::APInt{1, 1, false};
|
||||
}
|
||||
|
||||
@@ -108,12 +108,11 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op,
|
||||
|
||||
// Check the shape of l_cst argument
|
||||
auto width = ct.getWidth();
|
||||
auto expectedSize = 1 << width;
|
||||
auto lCstShape = l_cst.getShape();
|
||||
mlir::SmallVector<int64_t, 1> expectedShape{1 << width};
|
||||
mlir::SmallVector<int64_t, 1> expectedShape{expectedSize};
|
||||
if (!l_cst.hasStaticShape(expectedShape)) {
|
||||
op.emitOpError() << " should have as `l_cst` argument a shape of one "
|
||||
"dimension equals to 2^p, where p is the width of the "
|
||||
"`ct` argument.";
|
||||
emitErrorBadLutSize(op, "l_cst", "ct", expectedSize, width);
|
||||
return mlir::failure();
|
||||
}
|
||||
if (!l_cst.getElementType().isInteger(64)) {
|
||||
|
||||
@@ -281,6 +281,85 @@ verifyApplyMultiLookupTable(ApplyMultiLookupTableEintOp &op) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::RankedTensorType getTensorType(::mlir::Value value) {
|
||||
return value.getType().cast<mlir::RankedTensorType>();
|
||||
}
|
||||
|
||||
template <class T> T getElmentType(::mlir::Value value) {
|
||||
auto tTy = getTensorType(value);
|
||||
return tTy.getElementType().cast<T>();
|
||||
}
|
||||
|
||||
mlir::IntegerType getClearElmentType(::mlir::Value value) {
|
||||
return getElmentType<mlir::IntegerType>(value);
|
||||
}
|
||||
|
||||
HLFHE::EncryptedIntegerType getEncryptedElmentType(::mlir::Value value) {
|
||||
using namespace mlir::zamalang::HLFHE;
|
||||
return getElmentType<HLFHE::EncryptedIntegerType>(value);
|
||||
}
|
||||
|
||||
mlir::LogicalResult verifyMapHasRightShape(ApplyMappedLookupTableEintOp &op,
|
||||
::mlir::Value &lut_input,
|
||||
::mlir::Value &lut_map) {
|
||||
auto input_shape = getTensorType(lut_input).getShape();
|
||||
auto map_shape = getTensorType(lut_map).getShape();
|
||||
if (input_shape.equals(map_shape)) {
|
||||
return mlir::success();
|
||||
}
|
||||
std::string error;
|
||||
int input_rank = input_shape.size();
|
||||
int map_rank = map_shape.size();
|
||||
std::string input_name = "'t' (operand #1)";
|
||||
std::string map_name = "'lut_map.getName()' (operand #3)";
|
||||
if (input_rank == map_rank) {
|
||||
error = ": " + input_name + " dimensions differs from " + map_name;
|
||||
} else {
|
||||
error = ": " + input_name + " rank (=" + std::to_string(input_rank) +
|
||||
") differs from " + map_name +
|
||||
" rank (=" + std::to_string(map_rank) + ")";
|
||||
}
|
||||
op.emitOpError() << error;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op,
|
||||
::mlir::Value &encryptedIndex,
|
||||
::mlir::Value &luts) {
|
||||
auto index_width = getEncryptedElmentType(encryptedIndex).getWidth();
|
||||
auto actual_lut_size = getTensorType(luts).getShape().back();
|
||||
auto expected_lut_size = 1 << index_width;
|
||||
if (actual_lut_size == expected_lut_size) {
|
||||
return mlir::success();
|
||||
}
|
||||
HLFHE::emitErrorBadLutSize(op, "luts", "ct", expected_lut_size, index_width);
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
verifyApplyMappedLookupTable(ApplyMappedLookupTableEintOp &op) {
|
||||
auto t = op.t();
|
||||
auto luts = op.luts();
|
||||
auto map = op.map();
|
||||
auto result = op.getResult();
|
||||
|
||||
auto t_shape = getTensorType(t).getShape();
|
||||
if (!getTensorType(result).hasStaticShape(t_shape)) {
|
||||
op.emitOpError()
|
||||
<< ": `t` (operand #1) and `map` (operand #2) must have the same shape";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (!getTensorType(map).getElementType().isIndex()) {
|
||||
op.emitOpError()
|
||||
<< ": `map` (operand #3) should contains elements of type `index`";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success(verifyMapHasRightShape(op, t, map).succeeded() &&
|
||||
verifyLutsSize(op, t, luts).succeeded());
|
||||
}
|
||||
|
||||
::mlir::LogicalResult verifyDotEintInt(Dot &op) {
|
||||
if (::mlir::failed(mlir::verifyCompatibleShape(op.lhs().getType(),
|
||||
op.rhs().getType()))) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mlir/IR/Region.h"
|
||||
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
|
||||
|
||||
@@ -146,11 +147,10 @@ mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) {
|
||||
// Check the shape of l_cst argument
|
||||
auto width = ct.getP();
|
||||
auto lCstShape = l_cst.getShape();
|
||||
mlir::SmallVector<int64_t, 1> expectedShape{1 << width};
|
||||
auto expectedSize = 1 << width;
|
||||
mlir::SmallVector<int64_t, 1> expectedShape{expectedSize};
|
||||
if (!l_cst.hasStaticShape(expectedShape)) {
|
||||
op.emitOpError() << "should have as `l_cst` argument a shape of one "
|
||||
"dimension equals to 2^p, where p is the width of the "
|
||||
"`ct` argument.";
|
||||
HLFHE::emitErrorBadLutSize(op, "l_cst", "ct", expectedSize, width);
|
||||
return mlir::failure();
|
||||
}
|
||||
if (!l_cst.getElementType().isInteger(64)) {
|
||||
|
||||
@@ -92,8 +92,7 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
|
||||
llvm::InitializeNativeTargetAsmPrinter();
|
||||
mlir::registerLLVMDialectTranslation(mlirContext);
|
||||
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline =
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
auto optPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
|
||||
mlir::zamalang::JITLambda::create(funcName, module, optPipeline,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument.
|
||||
// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op : `l_cst` (operand #2) inner dimension should have size 4(=2^2) to match `ct` (operand #1) elements bitwidth (2)
|
||||
func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> {
|
||||
%1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<8xi3>) -> (!HLFHE.eint<2>)
|
||||
return %1: !HLFHE.eint<2>
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s
|
||||
|
||||
|
||||
//CHECK-LABEL: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
//CHECK-NEXT:module {
|
||||
//CHECK-NEXT: func @mapped_lut(%arg0: tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, %[[LUTS:.*]]: tensor<5x4xi64>, %arg2: tensor<2x3xindex>) -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>> {
|
||||
//CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [2, 3] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg2 : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>) {
|
||||
//CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %[[LUTIDX:.*]]: index, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors
|
||||
//DISABLED-CHECK-NEXT: %[[V3:.*]] = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1, 4] [1, 1] : tensor<5x4xi64> to tensor<4xi64>
|
||||
//WORKAROUND BEGIN
|
||||
//CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
|
||||
//CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
|
||||
//CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index
|
||||
//CHECK-NEXT: %[[C3:.*]] = arith.constant 3 : index
|
||||
//CHECK-NEXT: %[[E0:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C0]]] : tensor<5x4xi64>
|
||||
//CHECK-NEXT: %[[E1:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C1]]] : tensor<5x4xi64>
|
||||
//CHECK-NEXT: %[[E2:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C2]]] : tensor<5x4xi64>
|
||||
//CHECK-NEXT: %[[E3:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C3]]] : tensor<5x4xi64>
|
||||
//CHECK-NEXT: %[[LUT:.*]] = tensor.from_elements %[[E0]], %[[E1]], %[[E2]], %[[E3]] : tensor<4xi64>
|
||||
//WORKAROUND END
|
||||
//CHECK-NEXT: %[[V4:.*]] = "MidLFHE.apply_lookup_table"(%arg3, %[[LUT]]) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: linalg.yield %[[V4]] : !MidLFHE.glwe<{_,_,_}{2}>
|
||||
//CHECK-NEXT: } -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: return %[[V1]] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>
|
||||
//CHECK-NEXT: }
|
||||
//CHECK-NEXT: }
|
||||
func @mapped_lut(%t: tensor<2x3x!HLFHE.eint<2>>, %luts: tensor<5x4xi64>, %map: tensor<2x3xindex>) -> tensor<2x3x!HLFHE.eint<2>> {
|
||||
%1 = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map): (tensor<2x3x!HLFHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>) -> tensor<2x3x!HLFHE.eint<2>>
|
||||
return %1: tensor<2x3x!HLFHE.eint<2>>
|
||||
}
|
||||
@@ -158,6 +158,47 @@ func @apply_multi_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tenso
|
||||
return %1: tensor<2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.apply_mapped_lookup_table
|
||||
/////////////////////////////////////////////////
|
||||
func @apply_mapped_lookup_table_bad_lut_size_127_vs_128(
|
||||
%input: tensor<2x3x4x!HLFHE.eint<7>>,
|
||||
%luts: tensor<127xi64>,
|
||||
%map: tensor<2x3x4xindex>
|
||||
) -> tensor<2x3x4x!HLFHE.eint<7>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.apply_mapped_lookup_table' op : `luts` (operand #2) inner dimension should have size 128(=2^7) to match `ct` (operand #1) elements bitwidth (7)}}
|
||||
%1 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<127xi64>, tensor<2x3x4xindex>) -> tensor<2x3x4x!HLFHE.eint<7>>
|
||||
return %1: tensor<2x3x4x!HLFHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @apply_mapped_lookup_table_bad_map_size(
|
||||
%input: tensor<2x3x4x!HLFHE.eint<7>>,
|
||||
%luts: tensor<128xi64>,
|
||||
%map: tensor<2x3xindex>
|
||||
) -> tensor<2x3x4x!HLFHE.eint<7>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.apply_mapped_lookup_table' op : 't' (operand #1) rank (=3) differs from 'lut_map.getName()' (operand #3) rank (=2)}}
|
||||
%1 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<128xi64>, tensor<2x3xindex>) -> tensor<2x3x4x!HLFHE.eint<7>>
|
||||
return %1: tensor<2x3x4x!HLFHE.eint<7>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @apply_mapped_lookup_table_bad_map_elmt_type(
|
||||
%input: tensor<2x3x4x!HLFHE.eint<7>>,
|
||||
%luts: tensor<128xi64>,
|
||||
%map: tensor<2x3xindex>
|
||||
) -> tensor<2x3x4x!HLFHE.eint<7>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.apply_mapped_lookup_table' op : 't' (operand #1) rank (=3) differs from 'lut_map.getName()' (operand #3) rank (=2)}}
|
||||
%1 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<128xi64>, tensor<2x3xindex>) -> tensor<2x3x4x!HLFHE.eint<7>>
|
||||
return %1: tensor<2x3x4x!HLFHE.eint<7>>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
@@ -288,6 +288,22 @@ func @apply_multi_lookup_table_broadcast(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %a
|
||||
return %1: tensor<2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.apply_mapped_lookup_table
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// CHECK-LABEL: func @apply_mapped_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<7>>, %arg1: tensor<10x128xi64>, %arg2: tensor<2x3x4xindex>) -> tensor<2x3x4x!HLFHE.eint<7>> {
|
||||
func @apply_mapped_lookup_table(
|
||||
%input: tensor<2x3x4x!HLFHE.eint<7>>,
|
||||
%luts: tensor<10x128xi64>,
|
||||
%map: tensor<2x3x4xindex>
|
||||
) -> tensor<2x3x4x!HLFHE.eint<7>> {
|
||||
// CHECK-NEXT: %0 = "HLFHELinalg.apply_mapped_lookup_table"(%arg0, %arg1, %arg2) : (tensor<2x3x4x!HLFHE.eint<7>>, tensor<10x128xi64>, tensor<2x3x4xindex>) -> tensor<2x3x4x!HLFHE.eint<7>>
|
||||
// CHECK-NEXT: return %0 : tensor<2x3x4x!HLFHE.eint<7>>
|
||||
%0 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<10x128xi64>, tensor<2x3x4xindex>) -> (tensor<2x3x4x!HLFHE.eint<7>>)
|
||||
return %0: tensor<2x3x4x!HLFHE.eint<7>>
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.dot_eint_int
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
// Bad dimension of the lookup table
|
||||
func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> {
|
||||
// expected-error @+1 {{'MidLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument}}
|
||||
// expected-error @+1 {{'MidLFHE.apply_lookup_table' op : `l_cst` (operand #2) inner dimension should have size 128(=2^7) to match `ct` (operand #1) elements bitwidth (7)}}
|
||||
%1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {glweDimension = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32, outputSizeKS = 600 : i32}: (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<4xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>)
|
||||
return %1: !MidLFHE.glwe<{512,10,64}{2}>
|
||||
}
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
#include "end_to_end_jit_test.h"
|
||||
|
||||
namespace Z = mlir::zamalang;
|
||||
template <class Elmt>
|
||||
using tensorArgTy = Z::TensorLambdaArgument<Z::IntLambdaArgument<Elmt>>;
|
||||
|
||||
#define GET_3D(tensor, i, j, k, di, dj, dk) (tensor)[i * dj * dk + j * dk + k]
|
||||
#define GET_2D(tensor, i, j, di, dj) (tensor)[i * dj + j]
|
||||
|
||||
@@ -1149,6 +1153,110 @@ TEST(End2EndJit_HLFHELinalg, apply_multi_lookup_table_with_boradcast) {
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// HLFHELinalg apply_mapped_lookup_table /////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_HLFHELinalg, apply_mapped_lookup_table_sequential) {
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers.
|
||||
func @main(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>> {
|
||||
%1 = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) :
|
||||
(tensor<3x3x!HLFHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>>
|
||||
return %1: tensor<3x3x!HLFHE.eint<2>>
|
||||
}
|
||||
)XXX");
|
||||
uint8_t t[3][3]{
|
||||
{0, 1, 2},
|
||||
{3, 0, 1},
|
||||
{2, 3, 0},
|
||||
};
|
||||
uint64_t luts[9][4]{
|
||||
{3, 0, 0, 0}, {0, 3, 0, 0}, {0, 0, 3, 0},
|
||||
{0, 0, 0, 3}, {3, 0, 0, 0}, {0, 3, 0, 0},
|
||||
{0, 0, 3, 0}, {0, 0, 0, 3}, {3, 0, 0, 0},
|
||||
};
|
||||
uint64_t map[3][3]{
|
||||
{0, 1, 2},
|
||||
{3, 4, 5},
|
||||
{6, 7, 8},
|
||||
};
|
||||
uint8_t expected[3][3]{
|
||||
{3, 3, 3},
|
||||
{3, 3, 3},
|
||||
{3, 3, 3},
|
||||
};
|
||||
|
||||
tensorArgTy<uint8_t> tArg(t);
|
||||
tensorArgTy<uint64_t> lutsArg(luts), mapArg(map);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>(
|
||||
{&tArg, &lutsArg, &mapArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), 3 * 3);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 3; j++) {
|
||||
EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_HLFHELinalg, apply_mapped_lookup_table_same_lut) {
|
||||
|
||||
mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
// Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers.
|
||||
func @main(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>> {
|
||||
%1 = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) :
|
||||
(tensor<3x3x!HLFHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>>
|
||||
return %1: tensor<3x3x!HLFHE.eint<2>>
|
||||
}
|
||||
)XXX");
|
||||
uint8_t t[3][3]{
|
||||
{0, 1, 2},
|
||||
{3, 0, 1},
|
||||
{2, 3, 0},
|
||||
};
|
||||
uint64_t luts[9][4]{
|
||||
{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0},
|
||||
{0, 0, 0, 0}, {1, 2, 3, 1}, {0, 0, 0, 0},
|
||||
{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0},
|
||||
};
|
||||
uint64_t map[3][3]{
|
||||
{4, 4, 4},
|
||||
{4, 4, 4},
|
||||
{4, 4, 4},
|
||||
};
|
||||
uint8_t expected[3][3]{
|
||||
{1, 2, 3},
|
||||
{1, 1, 2},
|
||||
{3, 1, 1},
|
||||
};
|
||||
|
||||
tensorArgTy<uint8_t> tArg(t);
|
||||
tensorArgTy<uint64_t> lutsArg(luts), mapArg(map);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>(
|
||||
{&tArg, &lutsArg, &mapArg});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), 3 * 3);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 3; j++) {
|
||||
EXPECT_EQ((*res)[i * 3 + j], expected[i][j])
|
||||
<< ", at pos(" << i << "," << j << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// HLFHELinalg dot_eint_int ///////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user