feat(compiler): fusing table lookups

This commit is contained in:
Umut
2024-02-12 16:51:16 +03:00
parent efc9314d25
commit 29503dfc17
6 changed files with 829 additions and 0 deletions

View File

@@ -422,7 +422,9 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure, ConstantNoi
let arguments = (ins FHE_AnyEncryptedInteger:$a,
TensorOf<[AnyInteger]>:$lut);
let results = (outs FHE_AnyEncryptedInteger);
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def FHE_RoundEintOp: FHE_Op<"round", [Pure, UnaryEint, DeclareOpInterfaceMethods<UnaryEint, ["sqMANP"]>]> {

View File

@@ -466,6 +466,7 @@ def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", [Pure,
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def FHELinalg_ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_table", [Pure, ConstantNoise]> {
@@ -567,6 +568,7 @@ def FHELinalg_ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_t
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;
let hasCanonicalizer = 1;
}
def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods<Binary, ["sqMANP"]>]> {

View File

@@ -394,6 +394,207 @@ void MulEintIntOp::getCanonicalizationPatterns(
patterns.add<ZeroEncOpPattern>(context);
}
void ApplyLookupTableEintOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
class AfterTluPattern
: public mlir::OpRewritePattern<ApplyLookupTableEintOp> {
public:
AfterTluPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<ApplyLookupTableEintOp>(context, 0) {}
mlir::LogicalResult
matchAndRewrite(ApplyLookupTableEintOp currentOperation,
mlir::PatternRewriter &rewriter) const override {
auto intermediateValue = currentOperation.getA();
auto intermediateOperation =
llvm::dyn_cast_or_null<ApplyLookupTableEintOp>(
intermediateValue.getDefiningOp());
if (!intermediateOperation) {
return mlir::failure();
}
auto intermediateTableValue = intermediateOperation.getLut();
auto intermediateTableOperation =
llvm::dyn_cast_or_null<arith::ConstantOp>(
intermediateTableValue.getDefiningOp());
auto currentTableValue = currentOperation.getLut();
auto currentTableOperation = llvm::dyn_cast_or_null<arith::ConstantOp>(
currentTableValue.getDefiningOp());
if (!intermediateTableOperation || !currentTableOperation) {
return mlir::success();
}
auto intermediateTableContentAttr =
(intermediateTableOperation.getValueAttr()
.dyn_cast_or_null<mlir::DenseIntElementsAttr>());
auto currentTableContentAttr =
(currentTableOperation.getValueAttr()
.dyn_cast_or_null<mlir::DenseIntElementsAttr>());
if (!intermediateTableContentAttr || !currentTableContentAttr) {
return mlir::failure();
}
auto intermediateTableContent =
(intermediateTableContentAttr.getValues<int64_t>());
auto currentTableContent = (currentTableContentAttr.getValues<int64_t>());
auto inputValue = intermediateOperation.getA();
auto inputType = inputValue.getType().dyn_cast<FheIntegerInterface>();
auto inputBitWidth = inputType.getWidth();
auto inputIsSigned = inputType.isSigned();
auto intermediateType =
(intermediateValue.getType().dyn_cast<FheIntegerInterface>());
auto intermediateBitWidth = intermediateType.getWidth();
auto intermediateIsSigned = intermediateType.isSigned();
auto usersOfPreviousOperation = intermediateOperation->getUsers();
auto numberOfUsersOfPreviousOperation = std::distance(
usersOfPreviousOperation.begin(), usersOfPreviousOperation.end());
if (numberOfUsersOfPreviousOperation > 1) {
// This is a special case.
//
// Imagine you have this structure:
// -----------------
// x: uint6
// y: uint3 = tlu[x]
// z: uint3 = y + 1
// a: uint3 = tlu[y]
// b: uint3 = a + z
// -----------------
//
// In this case, it's be better not to fuse `a = tlu[tlu[x]]`.
//
// The reason is that intermediate `y` is necessary for `z`,
// so it has to be computed anyway.
//
// So to calculate `a`, there are 2 options:
// - fused tlu on x
// - regular tlu on y
//
// So for such cases, it's only better to fuse if the
// bit width of `x` is smaller than the bit width of `y`.
auto shouldFuse = inputBitWidth < intermediateBitWidth;
if (!shouldFuse) {
return mlir::failure();
}
}
auto intermediateTableSize = 1 << inputBitWidth;
auto currentTableSize = 1 << intermediateBitWidth;
auto newTableContent = std::vector<int64_t>();
newTableContent.reserve(intermediateTableSize);
auto lookup = [&](ssize_t index) {
if (index < 0) {
index += intermediateTableSize;
}
auto resultOfFirstLookup = intermediateTableContent[index];
// If the result of the first lookup is negative
if (resultOfFirstLookup < 0) {
// We first add the table size to preserve semantics
// e.g., table[-1] == last element in the table == tableSize + (-1)
// e.g., table[-2] == one element before that == tableSize + (-2)
resultOfFirstLookup += currentTableSize;
// If it's still negative
if (resultOfFirstLookup < 0) {
// e.g., imagine first table resulted in -100_000
// (which can exist in tables...)
// then we set it to the smalles possible value
// of the input to the table
// So if -100 is encountered on a signed 7-bit tlu
// corresponding value will be calculated as if -64 is looked up
// [0, 1, 2, 3, -4, -3, -2, -1]
// ^^ smallest value will always be in the middle
resultOfFirstLookup = currentTableSize / 2;
}
} else if (resultOfFirstLookup >= currentTableSize) {
// Another special case is the result of the first table
// being bigger than the table itself
// In this case we approach the value as the
// biggest possible value of the input to the table
if (!intermediateIsSigned) {
// So if 100 is encountered on a unsigned 6-bit tlu
// corresponding value will be calculated as if 63 is looked up
// [0, 1, 2, 3, 4, 5, 6, 7]
// ^ biggest value will always be in the end
resultOfFirstLookup = currentTableSize - 1;
} else {
// So if 100 is encountered on a signed 7-bit tlu
// corresponding value will be calculated as if 63 is looked up
// [0, 1, 2, 3, -4, -3, -2, -1]
// ^ biggest value will always be in one before the middle
resultOfFirstLookup = (currentTableSize / 2) - 1;
}
}
auto resultOfSecondLookup = currentTableContent[resultOfFirstLookup];
return resultOfSecondLookup;
};
if (!inputIsSigned) {
// unsigned lookup table structure
// [0, 1, 2, 3, 4, 5, 6, 7]
// is the identity table
// for the whole table
for (ssize_t x = 0; x < intermediateTableSize; x++) {
newTableContent.push_back(lookup(x));
}
} else {
// signed lookup table structure
// [0, 1, 2, 3, -4, -3, -2, -1]
// is the identity table
// for the positive part
for (ssize_t x = 0; x < intermediateTableSize / 2; x++) {
newTableContent.push_back(lookup(x));
}
// for the negative part
for (ssize_t x = -(intermediateTableSize / 2); x < 0; x++) {
newTableContent.push_back(lookup(x));
}
}
auto newTable = rewriter.create<arith::ConstantOp>(
currentOperation.getLoc(),
DenseIntElementsAttr::get(intermediateTableValue.getType(),
newTableContent));
auto newOperation = rewriter.create<ApplyLookupTableEintOp>(
currentOperation.getLoc(), currentOperation.getType(), inputValue,
newTable);
rewriter.replaceAllUsesWith(currentOperation, newOperation);
return mlir::success();
}
};
patterns.add<AfterTluPattern>(context);
}
template <typename SignedConvOp>
void getSignedConvCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
mlir::MLIRContext *context) {

View File

@@ -1687,6 +1687,450 @@ void ToUnsignedOp::getCanonicalizationPatterns(
getSignedConvCanonicalizationPatterns<ToUnsignedOp>(patterns, context);
}
std::optional<mlir::Value>
fuseBackToBackTableLookups(mlir::Operation *currentOperation,
mlir::PatternRewriter &rewriter) {
using mlir::concretelang::FHE::FheIntegerInterface;
auto currentOperationAsTlu =
llvm::dyn_cast<ApplyLookupTableEintOp>(currentOperation);
auto currentOperationAsMappedTlu =
llvm::dyn_cast<ApplyMappedLookupTableEintOp>(currentOperation);
if (!currentOperationAsTlu && !currentOperationAsMappedTlu) {
return std::nullopt;
}
auto intermediateValue =
(currentOperationAsTlu ? (currentOperationAsTlu.getT())
: (currentOperationAsMappedTlu.getT()));
auto intermediateOperation = intermediateValue.getDefiningOp();
if (!intermediateOperation) {
return std::nullopt;
}
auto intermediateOperationAsTlu =
llvm::dyn_cast<ApplyLookupTableEintOp>(intermediateOperation);
auto intermediateOperationAsMappedTlu =
llvm::dyn_cast<ApplyMappedLookupTableEintOp>(intermediateOperation);
if (!intermediateOperationAsTlu && !intermediateOperationAsMappedTlu) {
return std::nullopt;
}
auto inputValue =
(intermediateOperationAsTlu ? (intermediateOperationAsTlu.getT())
: (intermediateOperationAsMappedTlu.getT()));
struct Indexer {
int64_t tableSize;
bool isSigned;
Indexer(int64_t tableSize, bool isSigned)
: tableSize{tableSize}, isSigned{isSigned} {}
virtual ~Indexer() = default;
int64_t sanitizeIndex(int64_t index) const {
// Same logic as the lookup lambda in
// cannonicalization of FHE.apply_lookup_table.
// See FHEOps.cpp for explanation of the following code.
if (index < 0) {
index += tableSize;
if (index < 0) {
index = tableSize / 2;
}
} else if (index >= tableSize) {
if (!isSigned) {
index = tableSize - 1;
} else {
index = (tableSize / 2) - 1;
}
}
return index;
}
virtual int64_t get(int64_t index, int64_t position) const = 0;
};
struct TluIdexer : public Indexer {
std::vector<int64_t> tableContent;
TluIdexer(int64_t tableSize, bool isSigned,
std::vector<int64_t> tableContent)
:
Indexer{tableSize, isSigned}, tableContent{std::move(tableContent)} {}
~TluIdexer() override = default;
static std::optional<std::unique_ptr<TluIdexer>>
create(ApplyLookupTableEintOp operation) {
auto tableValue = operation.getLut();
auto tableOperation =
llvm::dyn_cast_or_null<arith::ConstantOp>(tableValue.getDefiningOp());
if (!tableOperation) {
return std::nullopt;
}
auto tableContentAttr =
tableOperation.getValueAttr()
.dyn_cast_or_null<mlir::DenseIntElementsAttr>();
if (!tableContentAttr) {
return std::nullopt;
}
auto tableContent = std::vector<int64_t>();
for (auto value : tableContentAttr.getValues<int64_t>()) {
tableContent.push_back(value);
}
auto inputValue = operation.getT();
auto inputType = inputValue.getType()
.cast<RankedTensorType>()
.getElementType()
.dyn_cast<FheIntegerInterface>();
auto tableSize = 1 << inputType.getWidth();
auto isSigned = inputType.isSigned();
return std::make_unique<TluIdexer>(
TluIdexer(tableSize, isSigned, tableContent));
};
int64_t get(int64_t index, int64_t position) const override {
return tableContent[sanitizeIndex(index)];
}
};
struct MappedTluIdexer : public Indexer {
std::vector<int64_t> tablesContent;
std::vector<int64_t> mapContent;
MappedTluIdexer(int64_t tableSize, bool isSigned,
std::vector<int64_t> tablesContent,
std::vector<int64_t> mapContent)
:
Indexer{tableSize, isSigned}, tablesContent{std::move(tablesContent)},
mapContent{std::move(mapContent)} {}
~MappedTluIdexer() override = default;
static std::optional<std::unique_ptr<MappedTluIdexer>>
create(ApplyMappedLookupTableEintOp operation) {
auto tablesValue = operation.getLuts();
auto tablesOperation = llvm::dyn_cast_or_null<arith::ConstantOp>(
tablesValue.getDefiningOp());
if (!tablesOperation) {
return std::nullopt;
}
auto tablesContentAttr =
tablesOperation.getValueAttr()
.dyn_cast_or_null<mlir::DenseIntElementsAttr>();
if (!tablesContentAttr) {
return std::nullopt;
}
auto tablesContent = std::vector<int64_t>();
for (auto value : tablesContentAttr.getValues<int64_t>()) {
tablesContent.push_back(value);
}
auto mapValue = operation.getMap();
auto mapOperation =
llvm::dyn_cast_or_null<arith::ConstantOp>(mapValue.getDefiningOp());
if (!mapOperation) {
return std::nullopt;
}
auto mapContentAttr = mapOperation.getValueAttr()
.dyn_cast_or_null<mlir::DenseIntElementsAttr>();
if (!mapContentAttr) {
return std::nullopt;
}
auto mapContent = std::vector<int64_t>();
for (auto value : mapContentAttr.getValues<int64_t>()) {
mapContent.push_back(value);
}
auto inputValue = operation.getT();
auto inputType = inputValue.getType()
.cast<RankedTensorType>()
.getElementType()
.dyn_cast<FheIntegerInterface>();
auto tableSize = 1 << inputType.getWidth();
auto isSigned = inputType.isSigned();
return std::make_unique<MappedTluIdexer>(
MappedTluIdexer(tableSize, isSigned, tablesContent, mapContent));
}
int64_t get(int64_t index, int64_t position) const override {
int64_t tableIndex = mapContent[position];
return tablesContent[sanitizeIndex(index) + (tableIndex * tableSize)];
}
};
std::unique_ptr<Indexer> intermediateIndexer;
if (intermediateOperationAsTlu) {
auto indexer = TluIdexer::create(intermediateOperationAsTlu);
if (!indexer) {
return std::nullopt;
}
intermediateIndexer = std::move(*indexer);
} else {
auto indexer = MappedTluIdexer::create(intermediateOperationAsMappedTlu);
if (!indexer) {
return std::nullopt;
}
intermediateIndexer = std::move(*indexer);
}
std::unique_ptr<Indexer> currentIndexer;
if (currentOperationAsTlu) {
auto indexer = TluIdexer::create(currentOperationAsTlu);
if (!indexer) {
return std::nullopt;
}
currentIndexer = std::move(*indexer);
} else {
auto indexer = MappedTluIdexer::create(currentOperationAsMappedTlu);
if (!indexer) {
return std::nullopt;
}
currentIndexer = std::move(*indexer);
}
auto usersOfPreviousOperation = intermediateOperation->getUsers();
auto numberOfUsersOfPreviousOperation = std::distance(
usersOfPreviousOperation.begin(), usersOfPreviousOperation.end());
if (numberOfUsersOfPreviousOperation > 1) {
// This is a special case.
//
// Imagine you have this structure:
// -----------------
// x: uint6
// y: uint3 = tlu[x]
// z: uint3 = y + 1
// a: uint3 = tlu[y]
// b: uint3 = a + z
// -----------------
//
// In this case, it might be better not to fuse `a = tlu[tlu[x]]`.
//
// The reason is, intermediate `y` is necessary for `z`,
// so it have to be computed anyway.
//
// So to calculate `a`, there are 2 options:
// - fused tlu on x
// - regular tlu on y
//
// In this case, it's best to fuse only if
// bit width of `x` is smaller than bit width of `y`.
// We can use the table size as it's derived from the bit width
// and it preserves the ordering.
auto xTableSize = intermediateIndexer->tableSize;
auto yTableSize = currentIndexer->tableSize;
auto shouldFuse = xTableSize < yTableSize;
if (!shouldFuse) {
return std::nullopt;
}
}
auto resultingType =
(currentOperationAsTlu ? (currentOperationAsTlu.getType())
: (currentOperationAsMappedTlu.getType()));
if (intermediateOperationAsTlu && currentOperationAsTlu) {
auto newTableContent = std::vector<int64_t>();
newTableContent.reserve(intermediateIndexer->tableSize);
if (!intermediateIndexer->isSigned) {
for (ssize_t x = 0; x < intermediateIndexer->tableSize; x++) {
auto resultOfFirstTableLookup = intermediateIndexer->get(x, 0);
newTableContent.push_back(
currentIndexer->get(resultOfFirstTableLookup, 0));
}
} else {
for (ssize_t x = 0; x < intermediateIndexer->tableSize / 2; x++) {
auto resultOfFirstTableLookup = intermediateIndexer->get(x, 0);
newTableContent.push_back(
currentIndexer->get(resultOfFirstTableLookup, 0));
}
for (ssize_t x = -(intermediateIndexer->tableSize / 2); x < 0; x++) {
auto resultOfFirstTableLookup = intermediateIndexer->get(x, 0);
newTableContent.push_back(
currentIndexer->get(resultOfFirstTableLookup, 0));
}
}
auto newTableShape = std::vector<int64_t>{intermediateIndexer->tableSize};
auto newTableType = RankedTensorType::get(
newTableShape, IntegerType::get(currentOperation->getContext(), 64));
auto newTable = rewriter.create<arith::ConstantOp>(
currentOperation->getLoc(),
DenseIntElementsAttr::get(newTableType, newTableContent));
auto newOperation = rewriter.create<ApplyLookupTableEintOp>(
currentOperation->getLoc(), resultingType, inputValue, newTable);
return newOperation;
}
auto newTableContents = std::vector<std::vector<int64_t>>();
auto newMapContent = std::vector<int64_t>();
auto inputShape = inputValue.getType().cast<RankedTensorType>().getShape();
int64_t numberOfInputs = 1;
for (auto dimension : inputShape) {
numberOfInputs *= dimension;
}
for (int64_t position = 0; position < numberOfInputs; position++) {
auto newTableContent = std::vector<int64_t>();
newTableContent.reserve(intermediateIndexer->tableSize);
if (!intermediateIndexer->isSigned) {
for (ssize_t x = 0; x < intermediateIndexer->tableSize; x++) {
auto resultOfFirstTableLookup = intermediateIndexer->get(x, position);
newTableContent.push_back(
currentIndexer->get(resultOfFirstTableLookup, position));
}
} else {
for (ssize_t x = 0; x < intermediateIndexer->tableSize / 2; x++) {
auto resultOfFirstTableLookup = intermediateIndexer->get(x, position);
newTableContent.push_back(
currentIndexer->get(resultOfFirstTableLookup, position));
}
for (ssize_t x = -(intermediateIndexer->tableSize / 2); x < 0; x++) {
auto resultOfFirstTableLookup = intermediateIndexer->get(x, position);
newTableContent.push_back(
currentIndexer->get(resultOfFirstTableLookup, position));
}
}
auto search = std::find(newTableContents.begin(), newTableContents.end(),
newTableContent);
size_t index;
if (search == newTableContents.end()) {
index = newTableContents.size();
newTableContents.push_back(newTableContent);
} else {
index = std::distance(newTableContents.begin(), search);
}
newMapContent.push_back(index);
}
if (newTableContents.size() == 1) {
auto newTableShape = std::vector<int64_t>{intermediateIndexer->tableSize};
auto newTableType = RankedTensorType::get(
newTableShape, IntegerType::get(currentOperation->getContext(), 64));
auto newTable = rewriter.create<arith::ConstantOp>(
currentOperation->getLoc(),
DenseIntElementsAttr::get(newTableType, newTableContents[0]));
auto newOperation = rewriter.create<ApplyLookupTableEintOp>(
currentOperation->getLoc(), resultingType, inputValue, newTable);
return newOperation;
} else {
auto newTablesShape =
std::vector<int64_t>{static_cast<int64_t>(newTableContents.size()),
intermediateIndexer->tableSize};
auto newTablesType = RankedTensorType::get(
newTablesShape, IntegerType::get(currentOperation->getContext(), 64));
auto newTableContentsFlattened = std::vector<int64_t>();
for (auto newTableContent : newTableContents) {
newTableContentsFlattened.insert(newTableContentsFlattened.end(),
newTableContent.begin(),
newTableContent.end());
}
auto newTables = rewriter.create<arith::ConstantOp>(
currentOperation->getLoc(),
DenseIntElementsAttr::get(newTablesType, newTableContentsFlattened));
auto newMapShape = inputShape;
auto newMapType = RankedTensorType::get(
newMapShape, IndexType::get(currentOperation->getContext()));
auto newMap = rewriter.create<arith::ConstantOp>(
currentOperation->getLoc(),
DenseIntElementsAttr::get(newMapType, newMapContent));
auto newOperation = rewriter.create<ApplyMappedLookupTableEintOp>(
currentOperation->getLoc(), resultingType, inputValue, newTables,
newMap);
return newOperation;
}
return std::nullopt;
}
void ApplyLookupTableEintOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
class AfterTluPattern
: public mlir::OpRewritePattern<ApplyLookupTableEintOp> {
public:
AfterTluPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<ApplyLookupTableEintOp>(context, 0) {}
mlir::LogicalResult
matchAndRewrite(ApplyLookupTableEintOp currentOperation,
mlir::PatternRewriter &rewriter) const override {
auto replacement = fuseBackToBackTableLookups(currentOperation, rewriter);
if (replacement) {
rewriter.replaceAllUsesWith(currentOperation, *replacement);
return mlir::success();
}
return mlir::failure();
}
};
patterns.add<AfterTluPattern>(context);
}
void ApplyMappedLookupTableEintOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
class AfterTluPattern
: public mlir::OpRewritePattern<ApplyMappedLookupTableEintOp> {
public:
AfterTluPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<ApplyMappedLookupTableEintOp>(context, 0) {}
mlir::LogicalResult
matchAndRewrite(ApplyMappedLookupTableEintOp currentOperation,
mlir::PatternRewriter &rewriter) const override {
auto replacement = fuseBackToBackTableLookups(currentOperation, rewriter);
if (replacement) {
rewriter.replaceAllUsesWith(currentOperation, *replacement);
return mlir::success();
}
return mlir::failure();
}
};
patterns.add<AfterTluPattern>(context);
}
} // namespace FHELinalg
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,76 @@
// RUN: concretecompiler --split-input-file --action=dump-fhe --passes canonicalize %s 2>&1| FileCheck %s
// -----
// CHECK: func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<5> {
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 2, 4, 8, 12, 18, 24]> : tensor<8xi64>
// CHECK-NEXT: %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<5>
// CHECK-NEXT: return %0 : !FHE.eint<5>
// CHECK-NEXT: }
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<5> {
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 1, 4, 9, 16, 25, 36, 49]> : tensor<8xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<6>
%cst_0 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31]> : tensor<64xi64>
%1 = "FHE.apply_lookup_table"(%0, %cst_0) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<5>
return %1 : !FHE.eint<5>
}
// -----
// CHECK: func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<4> {
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 2, 4]> : tensor<4xi64>
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
// CHECK-NEXT: %0 = "FHE.apply_lookup_table"(%arg0, %cst_0) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<4>
// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<4>
// CHECK-NEXT: %2 = "FHE.add_eint"(%1, %0) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
// CHECK-NEXT: return %2 : !FHE.eint<4>
// CHECK-NEXT: }
func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<4> {
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<4>
%cst_0 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]> : tensor<16xi64>
%1 = "FHE.apply_lookup_table"(%0, %cst_0) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
%2 = "FHE.add_eint"(%1, %0) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
return %2 : !FHE.eint<4>
}
// -----
// CHECK: func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]> : tensor<16xi64>
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8]> : tensor<64xi64>
// CHECK-NEXT: %0 = "FHE.apply_lookup_table"(%arg0, %cst_0) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<4>
// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
// CHECK-NEXT: %2 = "FHE.add_eint"(%0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
// CHECK-NEXT: %3 = "FHE.add_eint"(%1, %2) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
// CHECK-NEXT: return %3 : !FHE.eint<4>
// CHECK-NEXT: }
func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
%cst = arith.constant dense<[0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8]> : tensor<64xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<4>
%c2_i3 = arith.constant 2 : i3
%cst_0 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]> : tensor<16xi64>
%1 = "FHE.apply_lookup_table"(%0, %cst_0) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
%2 = "FHE.add_eint"(%0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
%3 = "FHE.add_eint"(%1, %2) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
return %3 : !FHE.eint<4>
}
// -----
// CHECK: func.func @main(%arg0: !FHE.esint<5>) -> !FHE.esint<8> {
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 1, 8, 8, 27, 27, 64, 64, 125, 125, 216, 216, 343, 343, -512, -512, -343, -343, -216, -216, -125, -125, -64, -64, -27, -27, -8, -8, -1, -1]> : tensor<32xi64>
// CHECK-NEXT: %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.esint<8>
// CHECK-NEXT: return %0 : !FHE.esint<8>
// CHECK-NEXT: }
func.func @main(%arg0: !FHE.esint<5>) -> !FHE.esint<8> {
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, -8, -8, -7, -7, -6, -6, -5, -5, -4, -4, -3, -3, -2, -2, -1, -1]> : tensor<32xi64>
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.esint<4>
%c3_i3 = arith.constant 3 : i3
%cst_0 = arith.constant dense<[0, 1, 8, 27, 64, 125, 216, 343, -512, -343, -216, -125, -64, -27, -8, -1]> : tensor<16xi64>
%1 = "FHE.apply_lookup_table"(%0, %cst_0) : (!FHE.esint<4>, tensor<16xi64>) -> !FHE.esint<8>
return %1 : !FHE.esint<8>
}

View File

@@ -0,0 +1,104 @@
// RUN: concretecompiler --split-input-file --action=dump-fhe --passes canonicalize %s 2>&1| FileCheck %s
// -----
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<3>> {
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 2, 4]> : tensor<4xi64>
// CHECK-NEXT: %0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<3>>
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<3>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<3>> {
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
%0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<4>>
%cst_0 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]> : tensor<16xi64>
%1 = "FHELinalg.apply_lookup_table"(%0, %cst_0) : (tensor<2x!FHE.eint<4>>, tensor<16xi64>) -> tensor<2x!FHE.eint<3>>
return %1 : tensor<2x!FHE.eint<3>>
}
// -----
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<4>> {
// CHECK-NEXT: %cst = arith.constant dense<{{\[\[0, 0, 2, 4\], \[0, 0, 4, 13\]\]}}> : tensor<2x4xi64>
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1]> : tensor<2xindex>
// CHECK-NEXT: %0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst, %cst_0) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<4>>
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<4>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<4>> {
%cst = arith.constant dense<[0, 1]> : tensor<2xindex>
%cst_0 = arith.constant dense<[[0, 1, 4, 9], [0, 1, 8, 27]]> : tensor<2x4xi64>
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<5>>
%c2_i3 = arith.constant 2 : i3
%cst_1 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15]> : tensor<32xi64>
%1 = "FHELinalg.apply_lookup_table"(%0, %cst_1) : (tensor<2x!FHE.eint<5>>, tensor<32xi64>) -> tensor<2x!FHE.eint<4>>
return %1 : tensor<2x!FHE.eint<4>>
}
// -----
// CHECK: func.func @main(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<4>> {
// CHECK-NEXT: %cst = arith.constant dense<{{\[\[0, 0, 2, 4\], \[0, 0, 4, 13\]\]}}> : tensor<2x4xi64>
// CHECK-NEXT: %cst_0 = arith.constant dense<{{\[\[0, 1\], \[1, 0\], \[0, 1\]\]}}> : tensor<3x2xindex>
// CHECK-NEXT: %0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst, %cst_0) : (tensor<3x2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!FHE.eint<4>>
// CHECK-NEXT: return %0 : tensor<3x2x!FHE.eint<4>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<4>> {
%cst = arith.constant dense<[[0, 1], [1, 0], [0, 1]]> : tensor<3x2xindex>
%cst_0 = arith.constant dense<[[0, 1, 4, 9], [0, 1, 8, 27]]> : tensor<2x4xi64>
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<3x2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!FHE.eint<5>>
%c2_i3 = arith.constant 2 : i3
%cst_1 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15]> : tensor<32xi64>
%1 = "FHELinalg.apply_lookup_table"(%0, %cst_1) : (tensor<3x2x!FHE.eint<5>>, tensor<32xi64>) -> tensor<3x2x!FHE.eint<4>>
return %1 : tensor<3x2x!FHE.eint<4>>
}
// -----
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<3>> {
// CHECK-NEXT: %cst = arith.constant dense<{{\[\[0, 0, 2, 4\], \[0, 0, 1, 3\]\]}}> : tensor<2x4xi64>
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1]> : tensor<2xindex>
// CHECK-NEXT: %0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst, %cst_0) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<3>>
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<3>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<3>> {
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
%0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<4>>
%cst_0 = arith.constant dense<[0, 1]> : tensor<2xindex>
%cst_1 = arith.constant dense<[[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5]]> : tensor<2x16xi64>
%1 = "FHELinalg.apply_mapped_lookup_table"(%0, %cst_1, %cst_0) : (tensor<2x!FHE.eint<4>>, tensor<2x16xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<3>>
return %1 : tensor<2x!FHE.eint<3>>
}
// -----
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<4>> {
// CHECK-NEXT: %cst = arith.constant dense<{{\[\[0, 0, 1, 3\], \[0, 0, 4, 13\]\]}}> : tensor<2x4xi64>
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1]> : tensor<2xindex>
// CHECK-NEXT: %0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst, %cst_0) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<4>>
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<4>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<4>> {
%cst = arith.constant dense<[0, 1]> : tensor<2xindex>
%cst_0 = arith.constant dense<[[0, 1, 4, 9], [0, 1, 8, 27]]> : tensor<2x4xi64>
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<5>>
%cst_1 = arith.constant dense<[[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15]]> : tensor<2x32xi64>
%1 = "FHELinalg.apply_mapped_lookup_table"(%0, %cst_1, %cst) : (tensor<2x!FHE.eint<5>>, tensor<2x32xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<4>>
return %1 : tensor<2x!FHE.eint<4>>
}
// -----
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<10>> {
// CHECK-NEXT: %cst = arith.constant dense<[0, 1, 64, 729]> : tensor<4xi64>
// CHECK-NEXT: %0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<10>>
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<10>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<10>> {
%cst = arith.constant dense<[0, 1]> : tensor<2xindex>
%cst_0 = arith.constant dense<[[0, 1, 4, 9], [0, 1, 8, 27]]> : tensor<2x4xi64>
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<5>>
%cst_1 = arith.constant dense<[[0, 1, 8, 27, 64, 125, 216, 343, 512, 729, 1000, 1331, 1728, 2197, 2744, 3375, 4096, 4913, 5832, 6859, 8000, 9261, 10648, 12167, 13824, 15625, 17576, 19683, 21952, 24389, 27000, 29791], [0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625, 676, 729, 784, 841, 900, 961]]> : tensor<2x32xi64>
%1 = "FHELinalg.apply_mapped_lookup_table"(%0, %cst_1, %cst) : (tensor<2x!FHE.eint<5>>, tensor<2x32xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<10>>
return %1 : tensor<2x!FHE.eint<10>>
}