mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler): fusing table lookups
This commit is contained in:
@@ -422,7 +422,9 @@ def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [Pure, ConstantNoi
|
||||
let arguments = (ins FHE_AnyEncryptedInteger:$a,
|
||||
TensorOf<[AnyInteger]>:$lut);
|
||||
let results = (outs FHE_AnyEncryptedInteger);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def FHE_RoundEintOp: FHE_Op<"round", [Pure, UnaryEint, DeclareOpInterfaceMethods<UnaryEint, ["sqMANP"]>]> {
|
||||
|
||||
@@ -466,6 +466,7 @@ def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", [Pure,
|
||||
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def FHELinalg_ApplyMultiLookupTableEintOp : FHELinalg_Op<"apply_multi_lookup_table", [Pure, ConstantNoise]> {
|
||||
@@ -567,6 +568,7 @@ def FHELinalg_ApplyMappedLookupTableEintOp : FHELinalg_Op<"apply_mapped_lookup_t
|
||||
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int", [Pure, BinaryEintInt, DeclareOpInterfaceMethods<Binary, ["sqMANP"]>]> {
|
||||
|
||||
@@ -394,6 +394,207 @@ void MulEintIntOp::getCanonicalizationPatterns(
|
||||
patterns.add<ZeroEncOpPattern>(context);
|
||||
}
|
||||
|
||||
void ApplyLookupTableEintOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
|
||||
class AfterTluPattern
|
||||
: public mlir::OpRewritePattern<ApplyLookupTableEintOp> {
|
||||
public:
|
||||
AfterTluPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<ApplyLookupTableEintOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(ApplyLookupTableEintOp currentOperation,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
auto intermediateValue = currentOperation.getA();
|
||||
auto intermediateOperation =
|
||||
llvm::dyn_cast_or_null<ApplyLookupTableEintOp>(
|
||||
intermediateValue.getDefiningOp());
|
||||
|
||||
if (!intermediateOperation) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto intermediateTableValue = intermediateOperation.getLut();
|
||||
auto intermediateTableOperation =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(
|
||||
intermediateTableValue.getDefiningOp());
|
||||
|
||||
auto currentTableValue = currentOperation.getLut();
|
||||
auto currentTableOperation = llvm::dyn_cast_or_null<arith::ConstantOp>(
|
||||
currentTableValue.getDefiningOp());
|
||||
|
||||
if (!intermediateTableOperation || !currentTableOperation) {
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
auto intermediateTableContentAttr =
|
||||
(intermediateTableOperation.getValueAttr()
|
||||
.dyn_cast_or_null<mlir::DenseIntElementsAttr>());
|
||||
auto currentTableContentAttr =
|
||||
(currentTableOperation.getValueAttr()
|
||||
.dyn_cast_or_null<mlir::DenseIntElementsAttr>());
|
||||
|
||||
if (!intermediateTableContentAttr || !currentTableContentAttr) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto intermediateTableContent =
|
||||
(intermediateTableContentAttr.getValues<int64_t>());
|
||||
auto currentTableContent = (currentTableContentAttr.getValues<int64_t>());
|
||||
|
||||
auto inputValue = intermediateOperation.getA();
|
||||
auto inputType = inputValue.getType().dyn_cast<FheIntegerInterface>();
|
||||
auto inputBitWidth = inputType.getWidth();
|
||||
auto inputIsSigned = inputType.isSigned();
|
||||
|
||||
auto intermediateType =
|
||||
(intermediateValue.getType().dyn_cast<FheIntegerInterface>());
|
||||
auto intermediateBitWidth = intermediateType.getWidth();
|
||||
auto intermediateIsSigned = intermediateType.isSigned();
|
||||
|
||||
auto usersOfPreviousOperation = intermediateOperation->getUsers();
|
||||
auto numberOfUsersOfPreviousOperation = std::distance(
|
||||
usersOfPreviousOperation.begin(), usersOfPreviousOperation.end());
|
||||
|
||||
if (numberOfUsersOfPreviousOperation > 1) {
|
||||
// This is a special case.
|
||||
//
|
||||
// Imagine you have this structure:
|
||||
// -----------------
|
||||
// x: uint6
|
||||
// y: uint3 = tlu[x]
|
||||
// z: uint3 = y + 1
|
||||
// a: uint3 = tlu[y]
|
||||
// b: uint3 = a + z
|
||||
// -----------------
|
||||
//
|
||||
// In this case, it's be better not to fuse `a = tlu[tlu[x]]`.
|
||||
//
|
||||
// The reason is that intermediate `y` is necessary for `z`,
|
||||
// so it has to be computed anyway.
|
||||
//
|
||||
// So to calculate `a`, there are 2 options:
|
||||
// - fused tlu on x
|
||||
// - regular tlu on y
|
||||
//
|
||||
// So for such cases, it's only better to fuse if the
|
||||
// bit width of `x` is smaller than the bit width of `y`.
|
||||
|
||||
auto shouldFuse = inputBitWidth < intermediateBitWidth;
|
||||
if (!shouldFuse) {
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
auto intermediateTableSize = 1 << inputBitWidth;
|
||||
auto currentTableSize = 1 << intermediateBitWidth;
|
||||
|
||||
auto newTableContent = std::vector<int64_t>();
|
||||
newTableContent.reserve(intermediateTableSize);
|
||||
|
||||
auto lookup = [&](ssize_t index) {
|
||||
if (index < 0) {
|
||||
index += intermediateTableSize;
|
||||
}
|
||||
auto resultOfFirstLookup = intermediateTableContent[index];
|
||||
|
||||
// If the result of the first lookup is negative
|
||||
if (resultOfFirstLookup < 0) {
|
||||
// We first add the table size to preserve semantics
|
||||
// e.g., table[-1] == last element in the table == tableSize + (-1)
|
||||
// e.g., table[-2] == one element before that == tableSize + (-2)
|
||||
resultOfFirstLookup += currentTableSize;
|
||||
|
||||
// If it's still negative
|
||||
if (resultOfFirstLookup < 0) {
|
||||
// e.g., imagine first table resulted in -100_000
|
||||
// (which can exist in tables...)
|
||||
// then we set it to the smalles possible value
|
||||
// of the input to the table
|
||||
|
||||
// So if -100 is encountered on a signed 7-bit tlu
|
||||
// corresponding value will be calculated as if -64 is looked up
|
||||
|
||||
// [0, 1, 2, 3, -4, -3, -2, -1]
|
||||
// ^^ smallest value will always be in the middle
|
||||
|
||||
resultOfFirstLookup = currentTableSize / 2;
|
||||
}
|
||||
} else if (resultOfFirstLookup >= currentTableSize) {
|
||||
// Another special case is the result of the first table
|
||||
// being bigger than the table itself
|
||||
|
||||
// In this case we approach the value as the
|
||||
// biggest possible value of the input to the table
|
||||
|
||||
if (!intermediateIsSigned) {
|
||||
|
||||
// So if 100 is encountered on a unsigned 6-bit tlu
|
||||
// corresponding value will be calculated as if 63 is looked up
|
||||
|
||||
// [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
// ^ biggest value will always be in the end
|
||||
|
||||
resultOfFirstLookup = currentTableSize - 1;
|
||||
|
||||
} else {
|
||||
|
||||
// So if 100 is encountered on a signed 7-bit tlu
|
||||
// corresponding value will be calculated as if 63 is looked up
|
||||
|
||||
// [0, 1, 2, 3, -4, -3, -2, -1]
|
||||
// ^ biggest value will always be in one before the middle
|
||||
|
||||
resultOfFirstLookup = (currentTableSize / 2) - 1;
|
||||
}
|
||||
}
|
||||
auto resultOfSecondLookup = currentTableContent[resultOfFirstLookup];
|
||||
|
||||
return resultOfSecondLookup;
|
||||
};
|
||||
|
||||
if (!inputIsSigned) {
|
||||
// unsigned lookup table structure
|
||||
// [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
// is the identity table
|
||||
|
||||
// for the whole table
|
||||
for (ssize_t x = 0; x < intermediateTableSize; x++) {
|
||||
newTableContent.push_back(lookup(x));
|
||||
}
|
||||
} else {
|
||||
// signed lookup table structure
|
||||
// [0, 1, 2, 3, -4, -3, -2, -1]
|
||||
// is the identity table
|
||||
|
||||
// for the positive part
|
||||
for (ssize_t x = 0; x < intermediateTableSize / 2; x++) {
|
||||
newTableContent.push_back(lookup(x));
|
||||
}
|
||||
// for the negative part
|
||||
for (ssize_t x = -(intermediateTableSize / 2); x < 0; x++) {
|
||||
newTableContent.push_back(lookup(x));
|
||||
}
|
||||
}
|
||||
|
||||
auto newTable = rewriter.create<arith::ConstantOp>(
|
||||
currentOperation.getLoc(),
|
||||
DenseIntElementsAttr::get(intermediateTableValue.getType(),
|
||||
newTableContent));
|
||||
|
||||
auto newOperation = rewriter.create<ApplyLookupTableEintOp>(
|
||||
currentOperation.getLoc(), currentOperation.getType(), inputValue,
|
||||
newTable);
|
||||
|
||||
rewriter.replaceAllUsesWith(currentOperation, newOperation);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
patterns.add<AfterTluPattern>(context);
|
||||
}
|
||||
|
||||
template <typename SignedConvOp>
|
||||
void getSignedConvCanonicalizationPatterns(mlir::RewritePatternSet &patterns,
|
||||
mlir::MLIRContext *context) {
|
||||
|
||||
@@ -1687,6 +1687,450 @@ void ToUnsignedOp::getCanonicalizationPatterns(
|
||||
getSignedConvCanonicalizationPatterns<ToUnsignedOp>(patterns, context);
|
||||
}
|
||||
|
||||
std::optional<mlir::Value>
|
||||
fuseBackToBackTableLookups(mlir::Operation *currentOperation,
|
||||
mlir::PatternRewriter &rewriter) {
|
||||
|
||||
using mlir::concretelang::FHE::FheIntegerInterface;
|
||||
|
||||
auto currentOperationAsTlu =
|
||||
llvm::dyn_cast<ApplyLookupTableEintOp>(currentOperation);
|
||||
auto currentOperationAsMappedTlu =
|
||||
llvm::dyn_cast<ApplyMappedLookupTableEintOp>(currentOperation);
|
||||
|
||||
if (!currentOperationAsTlu && !currentOperationAsMappedTlu) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto intermediateValue =
|
||||
(currentOperationAsTlu ? (currentOperationAsTlu.getT())
|
||||
: (currentOperationAsMappedTlu.getT()));
|
||||
|
||||
auto intermediateOperation = intermediateValue.getDefiningOp();
|
||||
if (!intermediateOperation) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto intermediateOperationAsTlu =
|
||||
llvm::dyn_cast<ApplyLookupTableEintOp>(intermediateOperation);
|
||||
auto intermediateOperationAsMappedTlu =
|
||||
llvm::dyn_cast<ApplyMappedLookupTableEintOp>(intermediateOperation);
|
||||
|
||||
if (!intermediateOperationAsTlu && !intermediateOperationAsMappedTlu) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto inputValue =
|
||||
(intermediateOperationAsTlu ? (intermediateOperationAsTlu.getT())
|
||||
: (intermediateOperationAsMappedTlu.getT()));
|
||||
|
||||
struct Indexer {
|
||||
int64_t tableSize;
|
||||
bool isSigned;
|
||||
|
||||
Indexer(int64_t tableSize, bool isSigned)
|
||||
: tableSize{tableSize}, isSigned{isSigned} {}
|
||||
|
||||
virtual ~Indexer() = default;
|
||||
|
||||
int64_t sanitizeIndex(int64_t index) const {
|
||||
// Same logic as the lookup lambda in
|
||||
// cannonicalization of FHE.apply_lookup_table.
|
||||
// See FHEOps.cpp for explanation of the following code.
|
||||
if (index < 0) {
|
||||
index += tableSize;
|
||||
if (index < 0) {
|
||||
index = tableSize / 2;
|
||||
}
|
||||
} else if (index >= tableSize) {
|
||||
if (!isSigned) {
|
||||
index = tableSize - 1;
|
||||
} else {
|
||||
index = (tableSize / 2) - 1;
|
||||
}
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
virtual int64_t get(int64_t index, int64_t position) const = 0;
|
||||
};
|
||||
|
||||
struct TluIdexer : public Indexer {
|
||||
std::vector<int64_t> tableContent;
|
||||
|
||||
TluIdexer(int64_t tableSize, bool isSigned,
|
||||
std::vector<int64_t> tableContent)
|
||||
:
|
||||
|
||||
Indexer{tableSize, isSigned}, tableContent{std::move(tableContent)} {}
|
||||
|
||||
~TluIdexer() override = default;
|
||||
|
||||
static std::optional<std::unique_ptr<TluIdexer>>
|
||||
create(ApplyLookupTableEintOp operation) {
|
||||
auto tableValue = operation.getLut();
|
||||
|
||||
auto tableOperation =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(tableValue.getDefiningOp());
|
||||
if (!tableOperation) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto tableContentAttr =
|
||||
tableOperation.getValueAttr()
|
||||
.dyn_cast_or_null<mlir::DenseIntElementsAttr>();
|
||||
if (!tableContentAttr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto tableContent = std::vector<int64_t>();
|
||||
for (auto value : tableContentAttr.getValues<int64_t>()) {
|
||||
tableContent.push_back(value);
|
||||
}
|
||||
|
||||
auto inputValue = operation.getT();
|
||||
auto inputType = inputValue.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType()
|
||||
.dyn_cast<FheIntegerInterface>();
|
||||
|
||||
auto tableSize = 1 << inputType.getWidth();
|
||||
auto isSigned = inputType.isSigned();
|
||||
|
||||
return std::make_unique<TluIdexer>(
|
||||
TluIdexer(tableSize, isSigned, tableContent));
|
||||
};
|
||||
|
||||
int64_t get(int64_t index, int64_t position) const override {
|
||||
return tableContent[sanitizeIndex(index)];
|
||||
}
|
||||
};
|
||||
|
||||
struct MappedTluIdexer : public Indexer {
|
||||
std::vector<int64_t> tablesContent;
|
||||
std::vector<int64_t> mapContent;
|
||||
|
||||
MappedTluIdexer(int64_t tableSize, bool isSigned,
|
||||
std::vector<int64_t> tablesContent,
|
||||
std::vector<int64_t> mapContent)
|
||||
:
|
||||
|
||||
Indexer{tableSize, isSigned}, tablesContent{std::move(tablesContent)},
|
||||
mapContent{std::move(mapContent)} {}
|
||||
|
||||
~MappedTluIdexer() override = default;
|
||||
|
||||
static std::optional<std::unique_ptr<MappedTluIdexer>>
|
||||
create(ApplyMappedLookupTableEintOp operation) {
|
||||
auto tablesValue = operation.getLuts();
|
||||
|
||||
auto tablesOperation = llvm::dyn_cast_or_null<arith::ConstantOp>(
|
||||
tablesValue.getDefiningOp());
|
||||
if (!tablesOperation) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto tablesContentAttr =
|
||||
tablesOperation.getValueAttr()
|
||||
.dyn_cast_or_null<mlir::DenseIntElementsAttr>();
|
||||
if (!tablesContentAttr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto tablesContent = std::vector<int64_t>();
|
||||
for (auto value : tablesContentAttr.getValues<int64_t>()) {
|
||||
tablesContent.push_back(value);
|
||||
}
|
||||
|
||||
auto mapValue = operation.getMap();
|
||||
|
||||
auto mapOperation =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(mapValue.getDefiningOp());
|
||||
if (!mapOperation) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto mapContentAttr = mapOperation.getValueAttr()
|
||||
.dyn_cast_or_null<mlir::DenseIntElementsAttr>();
|
||||
if (!mapContentAttr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto mapContent = std::vector<int64_t>();
|
||||
for (auto value : mapContentAttr.getValues<int64_t>()) {
|
||||
mapContent.push_back(value);
|
||||
}
|
||||
|
||||
auto inputValue = operation.getT();
|
||||
auto inputType = inputValue.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType()
|
||||
.dyn_cast<FheIntegerInterface>();
|
||||
|
||||
auto tableSize = 1 << inputType.getWidth();
|
||||
auto isSigned = inputType.isSigned();
|
||||
|
||||
return std::make_unique<MappedTluIdexer>(
|
||||
MappedTluIdexer(tableSize, isSigned, tablesContent, mapContent));
|
||||
}
|
||||
|
||||
int64_t get(int64_t index, int64_t position) const override {
|
||||
int64_t tableIndex = mapContent[position];
|
||||
return tablesContent[sanitizeIndex(index) + (tableIndex * tableSize)];
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Indexer> intermediateIndexer;
|
||||
if (intermediateOperationAsTlu) {
|
||||
auto indexer = TluIdexer::create(intermediateOperationAsTlu);
|
||||
if (!indexer) {
|
||||
return std::nullopt;
|
||||
}
|
||||
intermediateIndexer = std::move(*indexer);
|
||||
} else {
|
||||
auto indexer = MappedTluIdexer::create(intermediateOperationAsMappedTlu);
|
||||
if (!indexer) {
|
||||
return std::nullopt;
|
||||
}
|
||||
intermediateIndexer = std::move(*indexer);
|
||||
}
|
||||
|
||||
std::unique_ptr<Indexer> currentIndexer;
|
||||
if (currentOperationAsTlu) {
|
||||
auto indexer = TluIdexer::create(currentOperationAsTlu);
|
||||
if (!indexer) {
|
||||
return std::nullopt;
|
||||
}
|
||||
currentIndexer = std::move(*indexer);
|
||||
} else {
|
||||
auto indexer = MappedTluIdexer::create(currentOperationAsMappedTlu);
|
||||
if (!indexer) {
|
||||
return std::nullopt;
|
||||
}
|
||||
currentIndexer = std::move(*indexer);
|
||||
}
|
||||
|
||||
auto usersOfPreviousOperation = intermediateOperation->getUsers();
|
||||
auto numberOfUsersOfPreviousOperation = std::distance(
|
||||
usersOfPreviousOperation.begin(), usersOfPreviousOperation.end());
|
||||
|
||||
if (numberOfUsersOfPreviousOperation > 1) {
|
||||
// This is a special case.
|
||||
//
|
||||
// Imagine you have this structure:
|
||||
// -----------------
|
||||
// x: uint6
|
||||
// y: uint3 = tlu[x]
|
||||
// z: uint3 = y + 1
|
||||
// a: uint3 = tlu[y]
|
||||
// b: uint3 = a + z
|
||||
// -----------------
|
||||
//
|
||||
// In this case, it might be better not to fuse `a = tlu[tlu[x]]`.
|
||||
//
|
||||
// The reason is, intermediate `y` is necessary for `z`,
|
||||
// so it have to be computed anyway.
|
||||
//
|
||||
// So to calculate `a`, there are 2 options:
|
||||
// - fused tlu on x
|
||||
// - regular tlu on y
|
||||
//
|
||||
// In this case, it's best to fuse only if
|
||||
// bit width of `x` is smaller than bit width of `y`.
|
||||
|
||||
// We can use the table size as it's derived from the bit width
|
||||
// and it preserves the ordering.
|
||||
auto xTableSize = intermediateIndexer->tableSize;
|
||||
auto yTableSize = currentIndexer->tableSize;
|
||||
|
||||
auto shouldFuse = xTableSize < yTableSize;
|
||||
if (!shouldFuse) {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
auto resultingType =
|
||||
(currentOperationAsTlu ? (currentOperationAsTlu.getType())
|
||||
: (currentOperationAsMappedTlu.getType()));
|
||||
|
||||
if (intermediateOperationAsTlu && currentOperationAsTlu) {
|
||||
auto newTableContent = std::vector<int64_t>();
|
||||
newTableContent.reserve(intermediateIndexer->tableSize);
|
||||
|
||||
if (!intermediateIndexer->isSigned) {
|
||||
for (ssize_t x = 0; x < intermediateIndexer->tableSize; x++) {
|
||||
auto resultOfFirstTableLookup = intermediateIndexer->get(x, 0);
|
||||
newTableContent.push_back(
|
||||
currentIndexer->get(resultOfFirstTableLookup, 0));
|
||||
}
|
||||
} else {
|
||||
for (ssize_t x = 0; x < intermediateIndexer->tableSize / 2; x++) {
|
||||
auto resultOfFirstTableLookup = intermediateIndexer->get(x, 0);
|
||||
newTableContent.push_back(
|
||||
currentIndexer->get(resultOfFirstTableLookup, 0));
|
||||
}
|
||||
for (ssize_t x = -(intermediateIndexer->tableSize / 2); x < 0; x++) {
|
||||
auto resultOfFirstTableLookup = intermediateIndexer->get(x, 0);
|
||||
newTableContent.push_back(
|
||||
currentIndexer->get(resultOfFirstTableLookup, 0));
|
||||
}
|
||||
}
|
||||
|
||||
auto newTableShape = std::vector<int64_t>{intermediateIndexer->tableSize};
|
||||
auto newTableType = RankedTensorType::get(
|
||||
newTableShape, IntegerType::get(currentOperation->getContext(), 64));
|
||||
|
||||
auto newTable = rewriter.create<arith::ConstantOp>(
|
||||
currentOperation->getLoc(),
|
||||
DenseIntElementsAttr::get(newTableType, newTableContent));
|
||||
|
||||
auto newOperation = rewriter.create<ApplyLookupTableEintOp>(
|
||||
currentOperation->getLoc(), resultingType, inputValue, newTable);
|
||||
|
||||
return newOperation;
|
||||
}
|
||||
|
||||
auto newTableContents = std::vector<std::vector<int64_t>>();
|
||||
auto newMapContent = std::vector<int64_t>();
|
||||
|
||||
auto inputShape = inputValue.getType().cast<RankedTensorType>().getShape();
|
||||
int64_t numberOfInputs = 1;
|
||||
for (auto dimension : inputShape) {
|
||||
numberOfInputs *= dimension;
|
||||
}
|
||||
|
||||
for (int64_t position = 0; position < numberOfInputs; position++) {
|
||||
auto newTableContent = std::vector<int64_t>();
|
||||
newTableContent.reserve(intermediateIndexer->tableSize);
|
||||
|
||||
if (!intermediateIndexer->isSigned) {
|
||||
for (ssize_t x = 0; x < intermediateIndexer->tableSize; x++) {
|
||||
auto resultOfFirstTableLookup = intermediateIndexer->get(x, position);
|
||||
newTableContent.push_back(
|
||||
currentIndexer->get(resultOfFirstTableLookup, position));
|
||||
}
|
||||
} else {
|
||||
for (ssize_t x = 0; x < intermediateIndexer->tableSize / 2; x++) {
|
||||
auto resultOfFirstTableLookup = intermediateIndexer->get(x, position);
|
||||
newTableContent.push_back(
|
||||
currentIndexer->get(resultOfFirstTableLookup, position));
|
||||
}
|
||||
for (ssize_t x = -(intermediateIndexer->tableSize / 2); x < 0; x++) {
|
||||
auto resultOfFirstTableLookup = intermediateIndexer->get(x, position);
|
||||
newTableContent.push_back(
|
||||
currentIndexer->get(resultOfFirstTableLookup, position));
|
||||
}
|
||||
}
|
||||
|
||||
auto search = std::find(newTableContents.begin(), newTableContents.end(),
|
||||
newTableContent);
|
||||
|
||||
size_t index;
|
||||
if (search == newTableContents.end()) {
|
||||
index = newTableContents.size();
|
||||
newTableContents.push_back(newTableContent);
|
||||
} else {
|
||||
index = std::distance(newTableContents.begin(), search);
|
||||
}
|
||||
|
||||
newMapContent.push_back(index);
|
||||
}
|
||||
|
||||
if (newTableContents.size() == 1) {
|
||||
auto newTableShape = std::vector<int64_t>{intermediateIndexer->tableSize};
|
||||
auto newTableType = RankedTensorType::get(
|
||||
newTableShape, IntegerType::get(currentOperation->getContext(), 64));
|
||||
|
||||
auto newTable = rewriter.create<arith::ConstantOp>(
|
||||
currentOperation->getLoc(),
|
||||
DenseIntElementsAttr::get(newTableType, newTableContents[0]));
|
||||
|
||||
auto newOperation = rewriter.create<ApplyLookupTableEintOp>(
|
||||
currentOperation->getLoc(), resultingType, inputValue, newTable);
|
||||
|
||||
return newOperation;
|
||||
} else {
|
||||
auto newTablesShape =
|
||||
std::vector<int64_t>{static_cast<int64_t>(newTableContents.size()),
|
||||
intermediateIndexer->tableSize};
|
||||
auto newTablesType = RankedTensorType::get(
|
||||
newTablesShape, IntegerType::get(currentOperation->getContext(), 64));
|
||||
|
||||
auto newTableContentsFlattened = std::vector<int64_t>();
|
||||
for (auto newTableContent : newTableContents) {
|
||||
newTableContentsFlattened.insert(newTableContentsFlattened.end(),
|
||||
newTableContent.begin(),
|
||||
newTableContent.end());
|
||||
}
|
||||
|
||||
auto newTables = rewriter.create<arith::ConstantOp>(
|
||||
currentOperation->getLoc(),
|
||||
DenseIntElementsAttr::get(newTablesType, newTableContentsFlattened));
|
||||
|
||||
auto newMapShape = inputShape;
|
||||
auto newMapType = RankedTensorType::get(
|
||||
newMapShape, IndexType::get(currentOperation->getContext()));
|
||||
|
||||
auto newMap = rewriter.create<arith::ConstantOp>(
|
||||
currentOperation->getLoc(),
|
||||
DenseIntElementsAttr::get(newMapType, newMapContent));
|
||||
|
||||
auto newOperation = rewriter.create<ApplyMappedLookupTableEintOp>(
|
||||
currentOperation->getLoc(), resultingType, inputValue, newTables,
|
||||
newMap);
|
||||
|
||||
return newOperation;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void ApplyLookupTableEintOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
|
||||
class AfterTluPattern
|
||||
: public mlir::OpRewritePattern<ApplyLookupTableEintOp> {
|
||||
public:
|
||||
AfterTluPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<ApplyLookupTableEintOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(ApplyLookupTableEintOp currentOperation,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto replacement = fuseBackToBackTableLookups(currentOperation, rewriter);
|
||||
if (replacement) {
|
||||
rewriter.replaceAllUsesWith(currentOperation, *replacement);
|
||||
return mlir::success();
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
};
|
||||
patterns.add<AfterTluPattern>(context);
|
||||
}
|
||||
|
||||
void ApplyMappedLookupTableEintOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
|
||||
class AfterTluPattern
|
||||
: public mlir::OpRewritePattern<ApplyMappedLookupTableEintOp> {
|
||||
public:
|
||||
AfterTluPattern(mlir::MLIRContext *context)
|
||||
: mlir::OpRewritePattern<ApplyMappedLookupTableEintOp>(context, 0) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(ApplyMappedLookupTableEintOp currentOperation,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto replacement = fuseBackToBackTableLookups(currentOperation, rewriter);
|
||||
if (replacement) {
|
||||
rewriter.replaceAllUsesWith(currentOperation, *replacement);
|
||||
return mlir::success();
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
};
|
||||
patterns.add<AfterTluPattern>(context);
|
||||
}
|
||||
|
||||
} // namespace FHELinalg
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
// RUN: concretecompiler --split-input-file --action=dump-fhe --passes canonicalize %s 2>&1| FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<5> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 2, 4, 8, 12, 18, 24]> : tensor<8xi64>
|
||||
// CHECK-NEXT: %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<5>
|
||||
// CHECK-NEXT: return %0 : !FHE.eint<5>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<5> {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 1, 4, 9, 16, 25, 36, 49]> : tensor<8xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<3>, tensor<8xi64>) -> !FHE.eint<6>
|
||||
%cst_0 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31]> : tensor<64xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%0, %cst_0) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<5>
|
||||
return %1 : !FHE.eint<5>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<4> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 2, 4]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %0 = "FHE.apply_lookup_table"(%arg0, %cst_0) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %2 = "FHE.add_eint"(%1, %0) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: return %2 : !FHE.eint<4>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<4> {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<4>
|
||||
%cst_0 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]> : tensor<16xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%0, %cst_0) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
|
||||
%2 = "FHE.add_eint"(%1, %0) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
return %2 : !FHE.eint<4>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]> : tensor<16xi64>
|
||||
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8]> : tensor<64xi64>
|
||||
// CHECK-NEXT: %0 = "FHE.apply_lookup_table"(%arg0, %cst_0) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %2 = "FHE.add_eint"(%0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %3 = "FHE.add_eint"(%1, %2) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
// CHECK-NEXT: return %3 : !FHE.eint<4>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: !FHE.eint<6>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
%cst = arith.constant dense<[0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8]> : tensor<64xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<4>
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst_0 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]> : tensor<16xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%0, %cst_0) : (!FHE.eint<4>, tensor<16xi64>) -> !FHE.eint<4>
|
||||
%2 = "FHE.add_eint"(%0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
%3 = "FHE.add_eint"(%1, %2) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
|
||||
return %3 : !FHE.eint<4>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: !FHE.esint<5>) -> !FHE.esint<8> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 1, 1, 8, 8, 27, 27, 64, 64, 125, 125, 216, 216, 343, 343, -512, -512, -343, -343, -216, -216, -125, -125, -64, -64, -27, -27, -8, -8, -1, -1]> : tensor<32xi64>
|
||||
// CHECK-NEXT: %0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.esint<8>
|
||||
// CHECK-NEXT: return %0 : !FHE.esint<8>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: !FHE.esint<5>) -> !FHE.esint<8> {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, -8, -8, -7, -7, -6, -6, -5, -5, -4, -4, -3, -3, -2, -2, -1, -1]> : tensor<32xi64>
|
||||
%0 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.esint<4>
|
||||
%c3_i3 = arith.constant 3 : i3
|
||||
%cst_0 = arith.constant dense<[0, 1, 8, 27, 64, 125, 216, 343, -512, -343, -216, -125, -64, -27, -8, -1]> : tensor<16xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%0, %cst_0) : (!FHE.esint<4>, tensor<16xi64>) -> !FHE.esint<8>
|
||||
return %1 : !FHE.esint<8>
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
// RUN: concretecompiler --split-input-file --action=dump-fhe --passes canonicalize %s 2>&1| FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<3>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<[0, 0, 2, 4]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<3>>
|
||||
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<3>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<3>> {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
|
||||
%0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<4>>
|
||||
%cst_0 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7]> : tensor<16xi64>
|
||||
%1 = "FHELinalg.apply_lookup_table"(%0, %cst_0) : (tensor<2x!FHE.eint<4>>, tensor<16xi64>) -> tensor<2x!FHE.eint<3>>
|
||||
return %1 : tensor<2x!FHE.eint<3>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<4>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<{{\[\[0, 0, 2, 4\], \[0, 0, 4, 13\]\]}}> : tensor<2x4xi64>
|
||||
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1]> : tensor<2xindex>
|
||||
// CHECK-NEXT: %0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst, %cst_0) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<4>>
|
||||
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<4>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<4>> {
|
||||
%cst = arith.constant dense<[0, 1]> : tensor<2xindex>
|
||||
%cst_0 = arith.constant dense<[[0, 1, 4, 9], [0, 1, 8, 27]]> : tensor<2x4xi64>
|
||||
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<5>>
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst_1 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15]> : tensor<32xi64>
|
||||
%1 = "FHELinalg.apply_lookup_table"(%0, %cst_1) : (tensor<2x!FHE.eint<5>>, tensor<32xi64>) -> tensor<2x!FHE.eint<4>>
|
||||
return %1 : tensor<2x!FHE.eint<4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<{{\[\[0, 0, 2, 4\], \[0, 0, 4, 13\]\]}}> : tensor<2x4xi64>
|
||||
// CHECK-NEXT: %cst_0 = arith.constant dense<{{\[\[0, 1\], \[1, 0\], \[0, 1\]\]}}> : tensor<3x2xindex>
|
||||
// CHECK-NEXT: %0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst, %cst_0) : (tensor<3x2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!FHE.eint<4>>
|
||||
// CHECK-NEXT: return %0 : tensor<3x2x!FHE.eint<4>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<4>> {
|
||||
%cst = arith.constant dense<[[0, 1], [1, 0], [0, 1]]> : tensor<3x2xindex>
|
||||
%cst_0 = arith.constant dense<[[0, 1, 4, 9], [0, 1, 8, 27]]> : tensor<2x4xi64>
|
||||
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<3x2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!FHE.eint<5>>
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst_1 = arith.constant dense<[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15]> : tensor<32xi64>
|
||||
%1 = "FHELinalg.apply_lookup_table"(%0, %cst_1) : (tensor<3x2x!FHE.eint<5>>, tensor<32xi64>) -> tensor<3x2x!FHE.eint<4>>
|
||||
return %1 : tensor<3x2x!FHE.eint<4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<3>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<{{\[\[0, 0, 2, 4\], \[0, 0, 1, 3\]\]}}> : tensor<2x4xi64>
|
||||
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1]> : tensor<2xindex>
|
||||
// CHECK-NEXT: %0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst, %cst_0) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<3>>
|
||||
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<3>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<3>> {
|
||||
%c2_i3 = arith.constant 2 : i3
|
||||
%cst = arith.constant dense<[0, 1, 4, 9]> : tensor<4xi64>
|
||||
%0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<4>>
|
||||
%cst_0 = arith.constant dense<[0, 1]> : tensor<2xindex>
|
||||
%cst_1 = arith.constant dense<[[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5]]> : tensor<2x16xi64>
|
||||
%1 = "FHELinalg.apply_mapped_lookup_table"(%0, %cst_1, %cst_0) : (tensor<2x!FHE.eint<4>>, tensor<2x16xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<3>>
|
||||
return %1 : tensor<2x!FHE.eint<3>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<4>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<{{\[\[0, 0, 1, 3\], \[0, 0, 4, 13\]\]}}> : tensor<2x4xi64>
|
||||
// CHECK-NEXT: %cst_0 = arith.constant dense<[0, 1]> : tensor<2xindex>
|
||||
// CHECK-NEXT: %0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst, %cst_0) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<4>>
|
||||
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<4>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<4>> {
|
||||
%cst = arith.constant dense<[0, 1]> : tensor<2xindex>
|
||||
%cst_0 = arith.constant dense<[[0, 1, 4, 9], [0, 1, 8, 27]]> : tensor<2x4xi64>
|
||||
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<5>>
|
||||
%cst_1 = arith.constant dense<[[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10], [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15]]> : tensor<2x32xi64>
|
||||
%1 = "FHELinalg.apply_mapped_lookup_table"(%0, %cst_1, %cst) : (tensor<2x!FHE.eint<5>>, tensor<2x32xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<4>>
|
||||
return %1 : tensor<2x!FHE.eint<4>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<10>> {
|
||||
// CHECK-NEXT: %cst = arith.constant dense<[0, 1, 64, 729]> : tensor<4xi64>
|
||||
// CHECK-NEXT: %0 = "FHELinalg.apply_lookup_table"(%arg0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<4xi64>) -> tensor<2x!FHE.eint<10>>
|
||||
// CHECK-NEXT: return %0 : tensor<2x!FHE.eint<10>>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%arg0: tensor<2x!FHE.eint<2>>) -> tensor<2x!FHE.eint<10>> {
|
||||
%cst = arith.constant dense<[0, 1]> : tensor<2xindex>
|
||||
%cst_0 = arith.constant dense<[[0, 1, 4, 9], [0, 1, 8, 27]]> : tensor<2x4xi64>
|
||||
%0 = "FHELinalg.apply_mapped_lookup_table"(%arg0, %cst_0, %cst) : (tensor<2x!FHE.eint<2>>, tensor<2x4xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<5>>
|
||||
%cst_1 = arith.constant dense<[[0, 1, 8, 27, 64, 125, 216, 343, 512, 729, 1000, 1331, 1728, 2197, 2744, 3375, 4096, 4913, 5832, 6859, 8000, 9261, 10648, 12167, 13824, 15625, 17576, 19683, 21952, 24389, 27000, 29791], [0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625, 676, 729, 784, 841, 900, 961]]> : tensor<2x32xi64>
|
||||
%1 = "FHELinalg.apply_mapped_lookup_table"(%0, %cst_1, %cst) : (tensor<2x!FHE.eint<5>>, tensor<2x32xi64>, tensor<2xindex>) -> tensor<2x!FHE.eint<10>>
|
||||
return %1 : tensor<2x!FHE.eint<10>>
|
||||
}
|
||||
Reference in New Issue
Block a user