feat(HLFHELinalg): add apply_mapped_table_lookup

Resolves #182
This commit is contained in:
rudy
2021-12-10 16:47:04 +01:00
committed by rudy-6-4
parent 81189ceaa9
commit d8fee32cea
15 changed files with 536 additions and 13 deletions

View File

@@ -19,6 +19,22 @@ bool verifyEncryptedIntegerAndIntegerInputsConsistency(OpState &op,
EncryptedIntegerType &a,
IntegerType &b);
/** Shared error message for all ApplyLookupTable variant Op (several Dialect)
* E.g. HLFHE.apply_lookup_table(input, lut)
* Message when the lut tensor has an invalid size,
* i.e. it cannot accomodate the input elements bitwidth
*/
template <class Op>
void emitErrorBadLutSize(Op &op, std::string lutName, std::string inputName,
int expectedSize, int bitWidth) {
auto s = op.emitOpError();
s << ": `" << lutName << "` (operand #2)"
<< " inner dimension should have size " << expectedSize << "(=2^"
<< bitWidth << ") to match "
<< "`" << inputName << "` (operand #1)"
<< " elements bitwidth (" << bitWidth << ")";
}
} // namespace HLFHE
} // namespace zamalang
} // namespace mlir

View File

@@ -342,6 +342,62 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []>
}];
}
def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", []> {
let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element, specified by a map.";
let description = [{
Performs for each encrypted indice a lookup on a table of clear integers. Multiple lookup tables are passed, and the application of lookup tables
is performed following the broadcasting rules. The precise lookup is specified by a map.
```mlir
// The result of this operation, is a tensor that contains the result of the lookup on different tables.
// i.e. %res[i, ..., k] = %luts[ %map[i, ..., k] ][ %t[i, ..., k] ]
%res = HLFHELinalg.apply_mapped_lookup_table(%t, %luts, %map): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<DM x ^$p>, tensor<DNx...xD1xindex> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
```
Examples:
```mlir
// Returns the lookup of 3x2 matrix of encrypted indices of width 2 on a vector of 2 tables of size 4=2^2 of clear integers.
//
// [0,1] [0, 1] = [1,2]
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
// [2,3] [0, 1] = [5,6]
"HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : (tensor<3x2x!HLFHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!HLFHE.eint<3>>
```
Others examples:
// [0,1] [1, 0] = [3,2]
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
// [2,3] [1, 0] = [4,7]
// [0,1] [0, 0] = [1,3]
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [1, 1] = [6,0]
// [2,3] [1, 0] = [4,7]
// [0,1] [0] = [1,3]
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [1] = [6,0]
// [2,3] [0] = [5,7]
// [0,1] = [1,2]
// [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0]
// [2,3] = [5,6]
}];
let arguments = (ins
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$t,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$luts,
Type<And<[TensorOf<[Index]>.predicate, HasStaticShapePred]>>:$map
);
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let verifier = [{
return ::mlir::zamalang::HLFHELinalg::verifyApplyMappedLookupTable(*this);
}];
}
// Dot product
def Dot : HLFHELinalg_Op<"dot_eint_int"> {
let summary = "Returns the encrypted dot product between a vector of encrypted integers and a vector of clean integers.";

View File

@@ -136,6 +136,15 @@ public:
llvm::ArrayRef<typename ScalarArgumentT::value_type> value)
: TensorLambdaArgument(value, {(int64_t)value.size()}) {}
template <std::size_t size1, std::size_t size2>
TensorLambdaArgument(
typename ScalarArgumentT::value_type (&a)[size1][size2]) {
dimensions = {size1, size2};
auto value = llvm::MutableArrayRef<typename ScalarArgumentT::value_type>(
(typename ScalarArgumentT::value_type *)a, size1 * size2);
std::copy(value.begin(), value.end(), std::back_inserter(this->value));
}
const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
// Returns the total number of elements in the tensor. If the number