From 88dd13756acd7befdf741e4b9d0efc4ef384c94b Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 19 Sep 2023 15:27:21 +0100 Subject: [PATCH] feat(compiler): support linalg.generic in the MANP Analysis --- .../concretelang/Dialect/FHE/Analysis/utils.h | 8 + .../include/concretelang/Support/Pipeline.h | 9 +- .../lib/Dialect/FHE/Analysis/MANP.cpp | 244 +++++++++++++++++- .../lib/Dialect/FHE/Analysis/utils.cpp | 17 ++ .../compiler/lib/Support/CompilerEngine.cpp | 31 ++- .../compiler/lib/Support/Pipeline.cpp | 14 +- .../Dialect/FHE/Analysis/MANP.mlir | 98 ++++--- .../Dialect/FHE/Analysis/MANP_conv2d.mlir | 12 +- .../Dialect/FHE/Analysis/MANP_linalg.mlir | 96 +++---- .../Analysis/MANP_linalg_no_canonicalize.mlir | 2 +- .../Dialect/FHE/Analysis/MANP_matmul.mlir | 2 +- .../Dialect/FHE/Analysis/MANP_tensor.mlir | 34 +-- .../Dialect/FHE/optimizer_ast.mlir | 2 +- .../FHELinalg/tensor-ops-to-linalg.mlir | 28 -- 14 files changed, 421 insertions(+), 176 deletions(-) delete mode 100644 compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tensor-ops-to-linalg.mlir diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h index 0da6574c2..e6124e878 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h @@ -6,6 +6,7 @@ #ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_UTILS_H #define CONCRETELANG_DIALECT_FHE_ANALYSIS_UTILS_H +#include #include namespace mlir { @@ -16,6 +17,13 @@ namespace utils { bool isEncryptedValue(mlir::Value value); unsigned int getEintPrecision(mlir::Value value); +/// \brief Returns the loop range on a linalg.genric operation. 
+/// +/// \param op +/// \return llvm::SmallVector +llvm::SmallVector +getLinalgGenericLoopRange(mlir::linalg::GenericOp op); + } // namespace utils } // namespace fhe } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h index 1ff061b3a..4f3eccae3 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/Pipeline.h @@ -40,9 +40,12 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::optional &fheContext, - std::function enablePass, - bool parallelize); + std::function enablePass); + +mlir::LogicalResult +lowerLinalgGenericToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + bool parallelizeLoops); mlir::LogicalResult transformFHEBoolean(mlir::MLIRContext &context, mlir::ModuleOp &module, diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 85c8ae9c9..554a51dbd 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #define GEN_PASS_CLASSES #include @@ -737,21 +738,49 @@ public: // Set minimal MANP for encrypted function arguments propagateIfChanged(lattice, lattice->join(MANPLatticeValue{ std::optional{llvm::APInt(1, 1)}})); + } + // In case of block arguments used in the block of a linalg.genric + // operation: map the MANP values of the operands into the block arguments + else if (lattice->getPoint().isa() || + mlir::concretelang::fhe::utils::isEncryptedValue( + lattice->getPoint())) { + mlir::Block *block = + lattice->getPoint().cast().getOwner(); + + if (block && block->getParentOp() && + llvm::isa(block->getParentOp())) { + auto genericOp = + mlir::dyn_cast(block->getParentOp()); + // Get the MANP from the corresponding input/output + auto argIndex = + lattice->getPoint().cast().getArgNumber(); + auto operandRange = genericOp.getInputs(); + if (argIndex >= operandRange.size()) { + argIndex -= operandRange.size(); + operandRange = genericOp.getOutputs(); + } + auto v = operandRange[argIndex]; + auto manp = this->getLatticeElement(v)->getValue().getMANP().value_or( + llvm::APInt(1, 1)); + propagateIfChanged(lattice, lattice->join(MANPLatticeValue{manp})); + } } else { // Everything else is initialized with an unset value propagateIfChanged(lattice, lattice->join(MANPLatticeValue{})); } } - void visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) override { - MANPLattice *latticeRes = results[0]; - + std::optional + norm2SqEquivFromOp(Operation *op, ArrayRef operands) { std::optional norm2SqEquiv; - if (auto cstNoiseOp = llvm::dyn_cast(op)) { - norm2SqEquiv = llvm::APInt{1, 1, false}; + if (llvm::isa(op)) { + norm2SqEquiv = llvm::APInt{1, 0, false}; + } else { + norm2SqEquiv = llvm::APInt{1, 1, false}; + } } else if (llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = getNoOpSqMANP(operands); @@ -876,8 +905,11 @@ public: norm2SqEquiv = {}; } } - - else if (llvm::isa(op)) { + // Linalg Generic + else if (auto linalgGenericOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = emulateLinalgGenric(linalgGenericOp, 
operands); + } else if (llvm::isa(op)) { norm2SqEquiv = {}; } else if (llvm::isa( *op->getDialect())) { @@ -886,6 +918,14 @@ public: } else { norm2SqEquiv = {}; } + return norm2SqEquiv; + } + + void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override { + MANPLattice *latticeRes = results[0]; + + std::optional norm2SqEquiv = norm2SqEquivFromOp(op, operands); if (norm2SqEquiv.has_value()) { latticeRes->join(MANPLatticeValue{norm2SqEquiv}); @@ -915,6 +955,194 @@ public: } } + /// Compute the flat index of a tensor given its shape, the current loop + /// indices, and the map between loop and tensor indices + size_t indexFromLoopRange(mlir::SmallVector loopIndices, + mlir::AffineMap map, + mlir::ArrayRef shape) { + llvm::SmallVector tensorIndices; + llvm::SmallVector constantIndices; + for (auto i : loopIndices) { + constantIndices.push_back( + IntegerAttr::get(IntegerType::get(map.getContext(), 64), i)); + } + assert(map.constantFold(constantIndices, tensorIndices).succeeded()); + assert(tensorIndices.size() == shape.size()); + + int64_t multiplier = 1; + size_t index = 0; + for (int64_t i = shape.size() - 1; i >= 0; i--) { + index += tensorIndices[i].cast().getInt() * multiplier; + multiplier *= shape[i]; + } + return index; + } + + // Compute the MANP value of a linalg.generic operation by emulating its + // execution + std::optional + emulateLinalgGenric(mlir::linalg::GenericOp genericOp, + llvm::ArrayRef operandMANPs) { + assert(genericOp.getOutputs().size() == 1 && + "MANP doesn't support linalg.genric with more than one output"); + + // We want to use a different mechanism to store MANP values than the + // Analysis. We don't want to write anything to the lattice values + // controlled by the analysis, but we will use them to read values that we + // don't yet have in the emulation + mlir::DenseMap valueToManp; + auto fetchOrFallbackToAnalysis = + [&](mlir::Value value) -> MANPLatticeValue * { + auto elem = valueToManp.find(value); + if (elem != valueToManp.end()) { + return &elem->second; + } + auto manp = getLatticeElement(value)->getValue(); + valueToManp[value] = manp; + return &valueToManp[value]; + }; + auto loopRange = + mlir::concretelang::fhe::utils::getLinalgGenericLoopRange(genericOp); + auto iterCount = std::accumulate(loopRange.begin(), loopRange.end(), 1, + std::multiplies()); + llvm::SmallVector strides; + auto stride = iterCount; + for (size_t i = 0; i < loopRange.size(); i++) { + stride /= loopRange[i]; + strides.push_back(stride); + } + + // clone the genricOp to replace block arguments with constant values when + // needed. 
The clone op must be destroyed at the end of the function + auto genericOpClone = genericOp.clone(); + + // init block arguments' MANP: map block arguments with op operands + for (auto arg : genericOpClone.getBlock()->getArguments()) { + auto argIndex = arg.getArgNumber(); + auto operandRange = genericOpClone.getInputs(); + if (argIndex >= operandRange.size()) { + argIndex -= operandRange.size(); + operandRange = genericOpClone.getOutputs(); + } + valueToManp[arg] = getLatticeElement(operandRange[argIndex])->getValue(); + } + + // keep track of the MANP of different elements in the output tensor + // (initialized to the initial output MANP value) + auto outputArg = genericOpClone.getBlock()->getArguments().back(); + auto outputType = + genericOpClone.getOutputs().front().getType().cast(); + auto outputSize = std::accumulate(outputType.getShape().begin(), + outputType.getShape().end(), 1, + std::multiplies()); + std::vector outputMANPs( + outputSize, fetchOrFallbackToAnalysis(outputArg)->getMANP().value()); + + // indices at a specific iteration + llvm::SmallVector indices(loopRange.size(), 0); + for (auto i = 0; i < iterCount; i++) { + for (size_t iterPos = 0; iterPos < indices.size(); iterPos++) { + indices[iterPos] = (i / strides[iterPos]) % loopRange[iterPos]; + } + + // if a linalg genric input is constant, replace the uses of its + // respective block argument with a constant value. This avoids the + // computation of the MANP to use conservative values. + // we also have to replace them back for the next iteration to be able to + // update them again with new values + mlir::DenseMap toReplaceBack; + for (auto arg : genericOpClone.getBlock()->getArguments()) { + auto argIndex = arg.getArgNumber(); + auto inputs = genericOpClone.getInputs(); + // don't consider outputs + if (argIndex >= inputs.size()) + continue; + auto input = inputs[argIndex]; + auto definingOp = input.getDefiningOp(); + if (definingOp && mlir::isa(definingOp)) { + // fetch constant value + auto constantOp = mlir::dyn_cast(definingOp); + mlir::DenseIntElementsAttr denseAttr = + constantOp.getValueAttr().dyn_cast(); + auto constantTensor = denseAttr.getValues(); + auto constantIndex = indexFromLoopRange( + indices, genericOpClone.getIndexingMapsArray()[argIndex], + denseAttr.getType().getShape()); + APInt constantValue = constantTensor[constantIndex]; + // create new constant op with constant value + auto opBuilder = mlir::OpBuilder(genericOpClone.getContext()); + auto opState = mlir::OperationState( + mlir::UnknownLoc::get(genericOpClone.getContext()), + arith::ConstantOp::getOperationName()); + arith::ConstantOp::build( + opBuilder, opState, + mlir::IntegerAttr::get(denseAttr.getType().getElementType(), + constantValue)); + auto newConstantOp = + arith::ConstantOp(mlir::Operation::create(opState)); + genericOpClone.getBlock()->push_front(newConstantOp); + // replace uses of the block argument with the constant value + arg.replaceAllUsesWith(newConstantOp.getResult()); + toReplaceBack[arg] = newConstantOp; + } + } + + // we want to replace the MANP of the block argument corresponding to the + // output with the MANP value corresponding to the currently accessed + // tensor element + size_t outputIndex = indexFromLoopRange( + indices, + genericOpClone.getIndexingMapsArray()[outputArg.getArgNumber()], + outputType.getShape()); + valueToManp[outputArg] = MANPLatticeValue(outputMANPs[outputIndex]); + genericOpClone.getBody()->walk([&](mlir::Operation *op) { + // we update the appropriate element's MANP value using the 
index of the + // currently accessed output element + if (auto yieldOp = mlir::dyn_cast(op)) { + auto manp = fetchOrFallbackToAnalysis(yieldOp->getOperand(0)); + outputMANPs[outputIndex] = manp->getMANP().value(); + return; + } + // compute using the op and operand manp values + mlir::SmallVector latticeOperands; + for (auto operand : op->getOperands()) { + auto lattice = new MANPLattice(operand); + lattice->join(*fetchOrFallbackToAnalysis(operand)); + latticeOperands.push_back(lattice); + } + std::optional norm2SqEquiv = + norm2SqEquivFromOp(op, latticeOperands); + // update the MANP of the result value + if (op->getNumResults() > 0) { + valueToManp[op->getResult(0).cast()] = + MANPLatticeValue(norm2SqEquiv); + } + // free space of lattice elements + for (auto toFree : latticeOperands) { + delete toFree; + } + }); + + // replace back the uses of block arguments which were replaced by + // constant values + for (auto replacement : toReplaceBack) { + auto blockArg = replacement.first; + auto constantOp = replacement.second; + auto constantValue = constantOp.getResult(); + constantValue.replaceAllUsesWith(blockArg); + constantOp->remove(); + constantOp->destroy(); + } + } + genericOpClone->destroy(); + // final result MANP is the max of output + llvm::APInt result = outputMANPs[0]; + for (auto manp : outputMANPs) { + result = APIntUMax(result, manp); + } + return result; + } + private: bool debug; }; diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/utils.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/utils.cpp index 239cdfb9d..5261c94a7 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/utils.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/utils.cpp @@ -5,6 +5,7 @@ #include "concretelang/Dialect/FHE/IR/FHETypes.h" #include +#include namespace mlir { namespace concretelang { @@ -52,6 +53,22 @@ unsigned int getEintPrecision(mlir::Value value) { return 0; } +llvm::SmallVector +getLinalgGenericLoopRange(mlir::linalg::GenericOp op) { + uint64_t loopRangeDim = op.getLoopsToShapesMap().getNumDims(); + llvm::SmallVector loopRange; + for (uint64_t i = 0; i < loopRangeDim; i++) { + mlir::Value mappedValue; + unsigned int pos; + assert( + op.mapIterationSpaceDimToOperandDim(i, mappedValue, pos).succeeded() && + "couldn't compute loop range"); + loopRange.push_back( + mappedValue.getType().cast().getShape()[pos]); + } + return loopRange; +} + } // namespace utils } // namespace fhe } // namespace concretelang diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index 0601a42ba..e8f676339 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -366,19 +366,12 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, return StreamStringError("Tiling of FHELinalg operations failed"); } - // Dataflow parallelization - if (dataflowParallelize && - mlir::concretelang::pipeline::autopar(mlirContext, module, enablePass) - .failed()) { - return StreamStringError("Dataflow parallelization failed"); - } - if (target == Target::FHE) return std::move(res); // FHELinalg -> FHE - if (mlir::concretelang::pipeline::lowerFHELinalgToFHE( - mlirContext, module, res.fheContext, enablePass, loopParallelize) + if (mlir::concretelang::pipeline::lowerFHELinalgToFHE(mlirContext, module, + enablePass) .failed()) { return 
StreamStringError("Lowering from FHELinalg to FHE failed"); } @@ -389,9 +382,29 @@ CompilerEngine::compile(mlir::ModuleOp moduleOp, Target target, return StreamStringError("Rewriting of high level fhe ops failed"); } + // TODO: bring determineFHEParameters call here after the FHELinalg -> FHE + // lowering + // require to first support linalg.genric in the Optimizer Dag creation + // FHE High level pass to determine FHE parameters + // if (auto err = this->determineFHEParameters(res)) + // return std::move(err); + if (target == Target::FHE_NO_LINALG) return std::move(res); + // Dataflow parallelization + if (dataflowParallelize && + mlir::concretelang::pipeline::autopar(mlirContext, module, enablePass) + .failed()) { + return StreamStringError("Dataflow parallelization failed"); + } + + if (mlir::concretelang::pipeline::lowerLinalgGenericToLoops( + mlirContext, module, enablePass, loopParallelize) + .failed()) { + return StreamStringError("Lowering from Linalg Generic to Loops failed"); + } + // FHE -> TFHE if (mlir::concretelang::pipeline::lowerFHEToTFHE(mlirContext, module, res.fheContext, enablePass) diff --git a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp index 228fc232a..c66579d58 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/Pipeline.cpp @@ -196,15 +196,23 @@ transformHighLevelFHEOps(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::LogicalResult lowerFHELinalgToFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::optional &fheContext, - std::function enablePass, - bool parallelizeLoops) { + std::function enablePass) { mlir::PassManager pm(&context); pipelinePrinting("FHELinalgToFHE", pm, context); addPotentiallyNestedPass( pm, mlir::concretelang::createConvertFHETensorOpsToLinalg(), enablePass); addPotentiallyNestedPass(pm, mlir::createLinalgGeneralizationPass(), enablePass); + + return pm.run(module.getOperation()); +} + +mlir::LogicalResult +lowerLinalgGenericToLoops(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass, + bool parallelizeLoops) { + mlir::PassManager pm(&context); + pipelinePrinting("LinalgGenericToLoops", pm, context); addPotentiallyNestedPass( pm, mlir::concretelang::createLinalgGenericOpWithTensorsToLoopsPass( diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir index eadd19c1d..2017726a9 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP.mlir @@ -1,8 +1,8 @@ -// RUN: concretecompiler --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s +// RUN: concretecompiler --passes fhe-tensor-ops-to-linalg --passes MANP --passes ConcreteOptimizer --action=dump-fhe-no-linalg --split-input-file %s 2>&1 | FileCheck %s func.func @single_zero() -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.zero"() {MANP = 1 : ui{{[[0-9]+}}} : () -> !FHE.eint<2> + // CHECK: MANP = 0 : ui{{[0-9]+}} %0 = "FHE.zero"() : () -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -12,7 +12,7 @@ func.func @single_zero() -> !FHE.eint<2> func.func @zero() -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHE.zero_tensor"() {MANP = 1 : ui{{[0-9]+}}} : () -> tensor<8x!FHE.eint<2>> + // CHECK: 
MANP = 0 : ui{{[0-9]+}} %0 = "FHE.zero_tensor"() : () -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -24,7 +24,7 @@ func.func @single_cst_add_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant 3 : i3 - // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.add_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -36,7 +36,7 @@ func.func @single_cst_add_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.add_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -46,7 +46,7 @@ func.func @single_cst_add_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_add_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.add_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -56,7 +56,7 @@ func.func @single_dyn_add_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> func.func @single_add_eint(%e0: !FHE.eint<2>, %e1: !FHE.eint<2>) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.add_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 2 : ui{{[0-9]+}} %0 = "FHE.add_eint"(%e0, %e1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -68,7 +68,7 @@ func.func @single_cst_sub_int_eint(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant 3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.sub_int_eint"(%cst, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -80,7 +80,7 @@ func.func @single_cst_sub_int_eint_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.sub_int_eint"(%cst, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -90,7 +90,7 @@ func.func @single_cst_sub_int_eint_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_sub_int_eint(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.sub_int_eint"(%i, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -102,7 +102,7 @@ func.func @single_cst_sub_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant 3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -114,7 +114,7 @@ func.func @single_cst_sub_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], 
%[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -124,7 +124,7 @@ func.func @single_cst_sub_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func @single_dyn_sub_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.sub_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -134,16 +134,16 @@ func.func @single_dyn_sub_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> func.func @chain_sub_eint(%e0: !FHE.eint<2>, %e1: !FHE.eint<2>, %e2: !FHE.eint<2>, %e3: !FHE.eint<2>, %e4: !FHE.eint<2>) -> !FHE.eint<2> { - // CHECK: %[[V0:.*]] = "FHE.sub_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 2 : ui{{[0-9]+}} %0 = "FHE.sub_eint"(%e0, %e1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> - // CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%[[V0]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 2 : ui{{[0-9]+}} %1 = "FHE.sub_eint"(%0, %e2) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> - // CHECK-NEXT: %[[V2:.*]] = "FHE.sub_eint"(%[[V1]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 2 : ui{{[0-9]+}} %2 = "FHE.sub_eint"(%1, %e3) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> - // CHECK-NEXT: %[[V3:.*]] = "FHE.sub_eint"(%[[V2]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 3 : ui{{[0-9]+}} %3 = "FHE.sub_eint"(%2, %e4) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> return %3 : !FHE.eint<2> @@ -153,7 +153,7 @@ func.func @chain_sub_eint(%e0: !FHE.eint<2>, %e1: !FHE.eint<2>, %e2: !FHE.eint<2 func.func @single_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.neg_eint"(%e) : (!FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -165,7 +165,7 @@ func.func @single_cst_mul_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant 3 : i3 - // CHECK: %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHE.mul_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -177,7 +177,7 @@ func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -3 : i3 - // CHECK: %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHE.mul_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -189,7 +189,7 @@ func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst = arith.constant -1 : i3 - // CHECK: %0 = "FHE.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.mul_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -199,7 +199,7 @@ func.func @single_cst_mul_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> func.func 
@single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHE.mul_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -208,7 +208,7 @@ func.func @single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> // ----- func.func @single_apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.apply_lookup_table"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> return %1: !FHE.eint<2> } @@ -222,13 +222,13 @@ func.func @chain_add_eint_int(%e: !FHE.eint<3>) -> !FHE.eint<3> %cst2 = arith.constant 2 : i4 %cst3 = arith.constant 1 : i4 - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.add_eint_int"(%e, %cst0) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint_int"(%[[V0]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %1 = "FHE.add_eint_int"(%0, %cst1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%[[V1]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %2 = "FHE.add_eint_int"(%1, %cst2) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint_int"(%[[V2]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %3 = "FHE.add_eint_int"(%2, %cst3) : (!FHE.eint<3>, i4) -> !FHE.eint<3> return %3 : !FHE.eint<3> @@ -243,13 +243,13 @@ func.func @dag_add_eint_int(%e: !FHE.eint<3>) -> !FHE.eint<3> %Acst2 = arith.constant 2 : i4 %Acst3 = arith.constant 1 : i4 - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %A0 = "FHE.add_eint_int"(%e, %Acst0) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint_int"(%[[V0]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %A1 = "FHE.add_eint_int"(%A0, %Acst1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%[[V1]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %A2 = "FHE.add_eint_int"(%A1, %Acst2) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint_int"(%[[V2]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %A3 = "FHE.add_eint_int"(%A2, %Acst3) : (!FHE.eint<3>, i4) -> !FHE.eint<3> %Bcst0 = arith.constant 1 : i4 @@ -259,20 +259,20 @@ func.func @dag_add_eint_int(%e: !FHE.eint<3>) -> !FHE.eint<3> %Bcst4 = arith.constant 4 : i4 %Bcst5 = arith.constant 7 : i4 - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %B0 = "FHE.add_eint_int"(%e, %Bcst0) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V1:.*]] = 
"FHE.add_eint_int"(%[[V0]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %B1 = "FHE.add_eint_int"(%B0, %Bcst1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%[[V1]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %B2 = "FHE.add_eint_int"(%B1, %Bcst2) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint_int"(%[[V2]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %B3 = "FHE.add_eint_int"(%B2, %Bcst3) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V4:.*]] = "FHE.add_eint_int"(%[[V3]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %B4 = "FHE.add_eint_int"(%B3, %Bcst4) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V5:.*]] = "FHE.add_eint_int"(%[[V4]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<3>, i4) -> !FHE.eint<3> + // CHECK: MANP = 1 : ui{{[0-9]+}} %B5 = "FHE.add_eint_int"(%B4, %Bcst5) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - // CHECK-NEXT: %[[V6:.*]] = "FHE.add_eint"(%[[V5]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> + // CHECK: MANP = 2 : ui{{[0-9]+}} %res = "FHE.add_eint"(%B5, %A3) : (!FHE.eint<3>, !FHE.eint<3>) -> !FHE.eint<3> return %A3 : !FHE.eint<3> @@ -282,16 +282,16 @@ func.func @dag_add_eint_int(%e: !FHE.eint<3>) -> !FHE.eint<3> func.func @chain_add_eint(%e0: !FHE.eint<2>, %e1: !FHE.eint<2>, %e2: !FHE.eint<2>, %e3: !FHE.eint<2>, %e4: !FHE.eint<2>) -> !FHE.eint<2> { - // CHECK: %[[V0:.*]] = "FHE.add_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 2 : ui{{[0-9]+}} %0 = "FHE.add_eint"(%e0, %e1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> - // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%[[V0]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 2 : ui{{[0-9]+}} %1 = "FHE.add_eint"(%0, %e2) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> - // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint"(%[[V1]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 2 : ui{{[0-9]+}} %2 = "FHE.add_eint"(%1, %e3) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> - // CHECK-NEXT: %[[V3:.*]] = "FHE.add_eint"(%[[V2]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 3 : ui{{[0-9]+}} %3 = "FHE.add_eint"(%2, %e4) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> return %3 : !FHE.eint<2> @@ -304,9 +304,9 @@ func.func @chain_add_eint_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2> { %cst0 = arith.constant 3 : i3 - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHE.add_eint_int"(%e, %cst0) : (!FHE.eint<2>, i3) -> !FHE.eint<2> - // CHECK-NEXT: %[[ret:.*]] = "FHE.neg_eint"(%[[V0]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>) -> !FHE.eint<2> + // CHECK: MANP = 1 : ui{{[0-9]+}} %1 = "FHE.neg_eint"(%0) : (!FHE.eint<2>) -> !FHE.eint<2> return %1 : !FHE.eint<2> @@ -316,9 +316,7 @@ func.func @chain_add_eint_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2> // CHECK-LABEL: @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> func.func @transpose_eint_3D(%arg0: 
tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> { - // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.transpose"(%arg0) {MANP = 1 : ui1, axes = []} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> - // CHECK-NEXT: return %[[v0]] : tensor<5x4x3x!FHE.eint<6>> - // CHECK-NEXT: } + // CHECK: MANP = 1 : ui{{[0-9]+}} %c = "FHELinalg.transpose"(%arg0) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> return %c : tensor<5x4x3x!FHE.eint<6>> } @@ -327,11 +325,9 @@ func.func @transpose_eint_3D(%arg0: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x // CHECK-LABEL: @transpose_eint_3D_after_op(%arg0: tensor<3x4x5x!FHE.eint<6>>, %arg1: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> func.func @transpose_eint_3D_after_op(%arg0: tensor<3x4x5x!FHE.eint<6>>, %arg1: tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> { - // CHECK-NEXT: %[[v0:.*]] = "FHELinalg.add_eint"(%arg0, %arg1) {MANP = 2 : ui{{[0-9]+}}} : (tensor<3x4x5x!FHE.eint<6>>, tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x4x5x!FHE.eint<6>> - // CHECK-NEXT: %[[v1:.*]] = "FHELinalg.transpose"(%[[v0]]) {MANP = 2 : ui{{[0-9]+}}, axes = []} : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> - // CHECK-NEXT: return %[[v1]] : tensor<5x4x3x!FHE.eint<6>> - // CHECK-NEXT: } + // CHECK: MANP = 2 : ui{{[0-9]+}} %sum = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x4x5x!FHE.eint<6>>, tensor<3x4x5x!FHE.eint<6>>) -> tensor<3x4x5x!FHE.eint<6>> + // CHECK: MANP = 2 : ui{{[0-9]+}} %c = "FHELinalg.transpose"(%sum) : (tensor<3x4x5x!FHE.eint<6>>) -> tensor<5x4x3x!FHE.eint<6>> return %c : tensor<5x4x3x!FHE.eint<6>> } diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir index b9de103ec..c7a0e4845 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir @@ -1,9 +1,9 @@ -// RUN: concretecompiler --passes canonicalize --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s +// RUN: concretecompiler --passes canonicalize --passes linalg-generalize-named-ops --passes fhe-tensor-ops-to-linalg --passes MANP --passes ConcreteOptimizer --action=dump-fhe-no-linalg --split-input-file %s 2>&1 | FileCheck %s func.func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> { %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7> %bias = arith.constant dense<[5]> : tensor<1xi7> - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 4 : ui{{[0-9]+}} + // CHECK: MANP = 4 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> @@ -14,7 +14,7 @@ func.func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> { %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7> - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 4 : ui{{[0-9]+}} + // CHECK: MANP = 4 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, 
%weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> @@ -25,7 +25,7 @@ func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : ten func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> { %bias = arith.constant dense<[5]> : tensor<1xi3> - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}} + // CHECK: MANP = 6 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> @@ -35,7 +35,7 @@ func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tens // ----- func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> { - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}} + // CHECK: MANP = 6 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> @@ -45,7 +45,7 @@ func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weigh // ----- func.func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> { - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 11 : ui{{[0-9]+}} + // CHECK: MANP = 11 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir index 4cf20197b..5b14fe85f 100644 --- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir +++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir @@ -1,10 +1,10 @@ -// RUN: concretecompiler --passes canonicalize --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s +// RUN: concretecompiler --passes canonicalize --passes fhe-tensor-ops-to-linalg --passes MANP --passes ConcreteOptimizer --action=dump-fhe-no-linalg --split-input-file %s 2>&1 | FileCheck %s func.func @single_cst_add_eint_int(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { %cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!FHE.eint<2>>, 
tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -17,7 +17,7 @@ func.func @single_cst_add_eint_int_from_cst_elements(%t: tensor<8x!FHE.eint<2>>) %cst1 = arith.constant 1 : i3 %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -26,7 +26,7 @@ func.func @single_cst_add_eint_int_from_cst_elements(%t: tensor<8x!FHE.eint<2>>) // ----- func.func @single_dyn_add_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.add_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -36,7 +36,7 @@ func.func @single_dyn_add_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) func.func @single_add_eint(%e0: tensor<8x!FHE.eint<2>>, %e1: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 2 : ui{{[0-9]+}} %0 = "FHELinalg.add_eint"(%e0, %e1) : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -48,7 +48,7 @@ func.func @single_cst_sub_int_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE. { %cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -61,7 +61,7 @@ func.func @single_cst_sub_int_eint_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) %cst1 = arith.constant 1 : i3 %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -73,7 +73,7 @@ func.func @single_cst_sub_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE. 
{ %cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -86,7 +86,7 @@ func.func @single_cst_sub_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) %cst1 = arith.constant 1 : i3 %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -96,7 +96,7 @@ func.func @single_cst_sub_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) func.func @single_sub_eint(%e0: tensor<8x!FHE.eint<2>>, %e1: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 2 : ui{{[0-9]+}} %0 = "FHELinalg.sub_eint"(%e0, %e1) : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -106,7 +106,7 @@ func.func @single_sub_eint(%e0: tensor<8x!FHE.eint<2>>, %e1: tensor<8x!FHE.eint< func.func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.neg_eint"(%e) : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -117,7 +117,7 @@ func.func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> func.func @single_dyn_sub_int_eint(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { // sqrt(1 + (2^2-1)^2) = 3.16 - // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.sub_int_eint"(%i, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -129,7 +129,7 @@ func.func @single_cst_mul_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE. 
{ %cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - // %0 = "FHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHELinalg.mul_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -142,7 +142,7 @@ func.func @single_cst_mul_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) %cst1 = arith.constant 2 : i3 %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> - // %0 = "FHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 2 : ui{{[0-9]+}} %0 = "FHELinalg.mul_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -153,7 +153,7 @@ func.func @single_cst_mul_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) func.func @single_dyn_mul_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { // sqrt(1 * (2^2-1)^2) = 3.16 - // CHECK: %[[ret:.*]] = "FHELinalg.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHELinalg.mul_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -167,13 +167,13 @@ func.func @chain_add_eint_int(%e: tensor<8x!FHE.eint<3>>) -> tensor<8x!FHE.eint< %cst1 = arith.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi4> %cst2 = arith.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi4> %cst3 = arith.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi4> - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %1 = "FHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %2 = "FHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %3 = "FHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> return %3 : tensor<8x!FHE.eint<3>> } @@ -183,9 +183,9 @@ func.func @chain_add_eint_int(%e: tensor<8x!FHE.eint<3>>) -> tensor<8x!FHE.eint< func.func @chain_add_eint_int_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { %cst0 = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : 
(tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %1 = "FHELinalg.neg_eint"(%0) : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %1 : tensor<8x!FHE.eint<2>> } @@ -198,7 +198,7 @@ func.func @chain_add_eint_int_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x! func.func @apply_lookup_table(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3>> { %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64> - // CHECK: %[[RES:.*]] = "FHELinalg.apply_lookup_table"(%[[T:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %res = "FHELinalg.apply_lookup_table"(%t, %lut) : (tensor<3x3x!FHE.eint<2>>, tensor<4xi64>) -> tensor<3x3x!FHE.eint<3>> return %res : tensor<3x3x!FHE.eint<3>> } @@ -207,9 +207,9 @@ func.func @apply_lookup_table(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.e func.func @apply_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<3>> { %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64> - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - // CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %res = "FHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>> return %res : tensor<8x!FHE.eint<3>> } @@ -218,7 +218,7 @@ func.func @apply_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8x func.func @apply_multi_lookup_table(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<3x3x4xi64>) -> tensor<3x3x!FHE.eint<3>> { - // CHECK: %[[RES:.*]] = "FHELinalg.apply_multi_lookup_table"(%[[T:.*]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!FHE.eint<3>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %res = "FHELinalg.apply_multi_lookup_table"(%t, %luts) : (tensor<3x3x!FHE.eint<2>>, tensor<3x3x4xi64>) -> tensor<3x3x!FHE.eint<3>> return %res : tensor<3x3x!FHE.eint<3>> } @@ -226,9 +226,9 @@ func.func @apply_multi_lookup_table(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor< // ----- func.func @apply_multi_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> { - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - // CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_multi_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> + // CHECK: MANP = 1 : ui{{[0-9]+}} %res = "FHELinalg.apply_multi_lookup_table"(%0, %luts) : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> 
tensor<8x!FHE.eint<3>> return %res : tensor<8x!FHE.eint<3>> } @@ -243,7 +243,7 @@ func.func @single_cst_dot(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2> { %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3> // sqrt(1^2*1 + 2^2*1 + 3^2*1 + 4^2*1) = 5.477225575 - // CHECK: %[[V0:.*]] = "FHELinalg.dot_eint_int"(%[[T:.*]], %[[CST:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> + // CHECK: MANP = 6 : ui{{[0-9]+}} %0 = "FHELinalg.dot_eint_int"(%t, %cst) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %0 : !FHE.eint<2> } @@ -254,7 +254,7 @@ func.func @single_cst_dot(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2> func.func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FHE.eint<2> { // sqrt(1^2*(2^2-1)^2*4) = 6 - // CHECK: %[[V0:.*]] = "FHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> + // CHECK: MANP = 6 : ui{{[0-9]+}} %0 = "FHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -265,12 +265,12 @@ func.func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FH func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2> { // sqrt((2^2-1)^2*1) = sqrt(9) = 3 - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> %cst = arith.constant dense<[1, 2, 3, -1]> : tensor<4xi3> // sqrt(1^2*9 + 2^2*9 + 3^2*9 + 1^2*9) = sqrt(135) = 12 - // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 12 : ui{{[[0-9]+}}} + // CHECK: MANP = 12 : ui{{[0-9]+}} %1 = "FHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %1 : !FHE.eint<2> @@ -281,11 +281,11 @@ func.func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) func.func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2> { // sqrt((2^2-1)^2*1) = sqrt(9) = 3 - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> // sqrt(3^2*(2^2-1)^2*4) = sqrt(324) = 18 - // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I:.*]]) {MANP = 18 : ui{{[0-9]+}}} + // CHECK: MANP = 18 : ui{{[0-9]+}} %1 = "FHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %1 : !FHE.eint<2> @@ -303,7 +303,7 @@ func.func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tenso // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 // manp(add_eint(mul, acc)) = 9 // ceil(sqrt(9)) = 3 - // CHECK: %[[V0:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 3 : ui{{[0-9]+}}} + // CHECK: MANP = 3 : ui{{[0-9]+}} %0 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> return %0 : tensor<3x2x!FHE.eint<2>> } @@ -319,7 +319,7 @@ func.func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tenso // manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 // manp(add_eint(mul, acc)) = 9 + 9 = 18 // ceil(sqrt(18)) = 5 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : 
ui{{[0-9]+}}}
+  // CHECK: MANP = 5 : ui{{[0-9]+}}
   %1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
   return %1 : tensor<3x2x!FHE.eint<2>>
 }
@@ -333,7 +333,7 @@ func.func @matmul_eint_int_cst_p_1(%arg0: tensor<3x1x!FHE.eint<2>>) -> tensor<3x
   // mul = manp(mul_eint_int(eint<2>, 3) = 1 * 3^2 = 9
   // manp(add_eint(mul, acc)) = 9
   // ceil(sqrt(10)) = 3
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 3 : ui{{[0-9]+}}}
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %1 = "FHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
   return %1 : tensor<3x2x!FHE.eint<2>>
 }
@@ -351,7 +351,7 @@ func.func @matmul_eint_int_cst_p_2_n_0(%arg0: tensor<3x2x!FHE.eint<2>>) -> tenso
   // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 16
   // manp(add_eint(mul, acc)) = 16 + 9 = 25
   // ceil(sqrt(25)) = 5
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
+  // CHECK: MANP = 5 : ui{{[0-9]+}}
   %1 = "FHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
   return %1 : tensor<3x2x!FHE.eint<2>>
 }
@@ -369,7 +369,7 @@ func.func @matmul_eint_int_cst_p_2_n_1(%arg0: tensor<3x2x!FHE.eint<2>>) -> tenso
   // mul = manp(mul_eint_int(eint<2>, 1) = 1 * 1^2 = 1
   // manp(add_eint(mul, acc)) = 1 + 16 = 17
   // ceil(sqrt(17)) = 5
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
+  // CHECK: MANP = 5 : ui{{[0-9]+}}
   %1 = "FHELinalg.matmul_eint_int"(%arg0, %0): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
   return %1 : tensor<3x2x!FHE.eint<2>>
 }
@@ -552,11 +552,11 @@ func.func @matmul_eint_int_cst(%0: tensor<4x3x!FHE.eint<7>>) -> (tensor<4x!FHE.e
 // -----

 func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x!FHE.eint<7>> {
-  // CHECK: {MANP = 1 : ui{{[0-9]+}}}
+  // CHECK: {MANP = 0 : ui{{[0-9]+}}}
   %z = "FHE.zero_tensor"() : () -> tensor<4x3x!FHE.eint<7>>
   %a = arith.constant dense<[[4, 6, 5], [2, 6, 3], [5, 6, 1], [5, 5, 3]]> : tensor<4x3xi8>

-  // CHECK: {MANP = 1 : ui{{[0-9]+}}}
+  // CHECK: {MANP = 0 : ui{{[0-9]+}}}
   %0 = "FHELinalg.add_eint_int"(%z, %a) : (tensor<4x3x!FHE.eint<7>>, tensor<4x3xi8>) -> tensor<4x3x!FHE.eint<7>>

   // ===============================
@@ -566,7 +566,7 @@ func.func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x!FHE.eint<7>
     [2, 1, 5]
   > : tensor<3xi8>

-  // CHECK: MANP = 6 : ui{{[0-9]+}}
+  // CHECK: MANP = 0 : ui{{[0-9]+}}
   %2 = "FHELinalg.matmul_eint_int"(%0, %1) : (tensor<4x3x!FHE.eint<7>>, tensor<3xi8>) -> tensor<4x!FHE.eint<7>>

   // ===============================
@@ -586,7 +586,7 @@ func.func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE
   // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
   // manp(add_eint(mul, acc)) = 9
   // ceil(sqrt(9)) = 3
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 3 : ui{{[0-9]+}}}
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
   return %1 : tensor<3x2x!FHE.eint<2>>
 }
@@ -601,7 +601,7 @@ func.func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!FHE
   // manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
   // manp(add_eint(mul, acc)) = 9 + 9 = 18
   // ceil(sqrt(18)) = 5
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
+  // CHECK: MANP = 5 : ui{{[0-9]+}}
   %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x2xi3>, tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
   return %1 : tensor<3x2x!FHE.eint<2>>
 }
@@ -615,7 +615,7 @@ func.func @matmul_int_eint_cst_p_1(%arg0: tensor<1x3x!FHE.eint<2>>) -> tensor<2x
   // mul = manp(mul_eint_int(eint<2>, 3) = 1^2 + 3^2 = 10
   // manp(add_eint(mul, acc)) = 10
   // ceil(sqrt(10)) = 4
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 3 : ui{{[0-9]+}}}
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %1 = "FHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x1xi3>, tensor<1x3x!FHE.eint<2>>) -> tensor<2x3x!FHE.eint<2>>
   return %1 : tensor<2x3x!FHE.eint<2>>
 }
@@ -633,7 +633,7 @@ func.func @matmul_int_eint_cst_p_2_n_0(%arg0: tensor<2x3x!FHE.eint<2>>) -> tenso
   // mul = manp(mul_eint_int(eint<2>, 4) = 1 * 4^2 = 17
   // manp(add_eint(mul, acc)) = 17 + 9 = 26
   // ceil(sqrt(26)) = 6
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
+  // CHECK: MANP = 5 : ui{{[0-9]+}}
   %1 = "FHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!FHE.eint<2>>) -> tensor<2x3x!FHE.eint<2>>
   return %1 : tensor<2x3x!FHE.eint<2>>
 }
@@ -651,7 +651,7 @@ func.func @matmul_int_eint_cst_p_2_n_1(%arg0: tensor<2x3x!FHE.eint<2>>) -> tenso
   // mul = manp(mul_eint_int(eint<2>, 1) = 1 * 1^2 = 1
   // manp(add_eint(mul, acc)) = 1 + 17 = 18
   // ceil(sqrt(18)) = 5
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
+  // CHECK: MANP = 5 : ui{{[0-9]+}}
   %1 = "FHELinalg.matmul_int_eint"(%0, %arg0): (tensor<2x2xi3>, tensor<2x3x!FHE.eint<2>>) -> tensor<2x3x!FHE.eint<2>>
   return %1 : tensor<2x3x!FHE.eint<2>>
 }
@@ -821,7 +821,7 @@ func.func @matmul_int_eint_cst_different_operand_manp() -> tensor<2x!FHE.eint<7>
   %z = "FHE.zero_tensor"() : () -> tensor<3x2x!FHE.eint<7>>
   %a = arith.constant dense<[[4, 6], [2, 6], [5, 6]]> : tensor<3x2xi8>

-  // CHECK: {MANP = 1 : ui{{[0-9]+}}}
+  // CHECK: {MANP = 0 : ui{{[0-9]+}}}
   %0 = "FHELinalg.add_eint_int"(%z, %a) : (tensor<3x2x!FHE.eint<7>>, tensor<3x2xi8>) -> tensor<3x2x!FHE.eint<7>>

   // ===============================
@@ -831,7 +831,7 @@ func.func @matmul_int_eint_cst_different_operand_manp() -> tensor<2x!FHE.eint<7>
     [2, 1, 5]
   > : tensor<3xi8>

-  // CHECK: MANP = 6 : ui{{[0-9]+}}
+  // CHECK: MANP = 0 : ui{{[0-9]+}}
   %2 = "FHELinalg.matmul_int_eint"(%1, %0) : (tensor<3xi8>, tensor<3x2x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>>

   // ===============================
diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg_no_canonicalize.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg_no_canonicalize.mlir
index 61dcb304e..9b2da8713 100644
--- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg_no_canonicalize.mlir
+++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg_no_canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: concretecompiler --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
+// RUN: concretecompiler --passes fhe-tensor-ops-to-linalg --passes MANP --passes ConcreteOptimizer --action=dump-fhe-no-linalg --split-input-file %s 2>&1 | FileCheck %s

 func.func @sum() -> !FHE.eint<7> {
   %0 = "FHE.zero_tensor"() : () -> tensor<5x3x4x2x!FHE.eint<7>>
diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_matmul.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_matmul.mlir
index 94c5b00cf..154d453c0 100644
--- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_matmul.mlir
+++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: concretecompiler --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
+// RUN: concretecompiler --passes fhe-tensor-ops-to-linalg --passes MANP --passes ConcreteOptimizer --action=dump-fhe-no-linalg --split-input-file %s 2>&1 | FileCheck %s

 func.func @main(%arg0: tensor<1x10x!FHE.eint<33>>) -> tensor<1x1x!FHE.eint<33>> {
   // sqrt(7282^2 + 20329^2 + 7232^2 + 32768 ^2 + 6446^2 + 32767^2 + 4708^2 + 20050^2 + 28812^2 + 17300^2) = 65277.528491817
diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir
index e9e137c11..ca825d68e 100644
--- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir
+++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_tensor.mlir
@@ -1,9 +1,9 @@
-// RUN: concretecompiler --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
+// RUN: concretecompiler --passes fhe-tensor-ops-to-linalg --passes MANP --passes ConcreteOptimizer --action=dump-fhe-no-linalg --split-input-file %s 2>&1 | FileCheck %s

 func.func @tensor_from_elements_1(%a: !FHE.eint<2>, %b: !FHE.eint<2>, %c: !FHE.eint<2>, %d: !FHE.eint<2>) -> tensor<4x!FHE.eint<2>>
 {
   // The MANP value is 1 as all operands are function arguments
-  // CHECK: %[[ret:.*]] = tensor.from_elements %[[a:.*]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 1 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
+  // CHECK: MANP = 1 : ui{{[0-9]+}}
   %0 = tensor.from_elements %a, %b, %c, %d : tensor<4x!FHE.eint<2>>

   return %0 : tensor<4x!FHE.eint<2>>
@@ -15,11 +15,11 @@ func.func @tensor_from_elements_2(%a: !FHE.eint<2>, %b: !FHE.eint<2>, %c: !FHE.e
 {
   %cst = arith.constant 3 : i3

-  // CHECK: %[[V0:.*]] = "FHE.mul_eint_int"(%[[a:.*]], %[[cst:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %0 = "FHE.mul_eint_int"(%a, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2>

   // The MANP value is 3, i.e. the max of all of its operands
-  // CHECK: %[[V1:.*]] = tensor.from_elements %[[V0]], %[[b:.*]], %[[c:.*]], %[[d:.*]] {MANP = 3 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %1 = tensor.from_elements %0, %b, %c, %d : tensor<4x!FHE.eint<2>>

   return %1 : tensor<4x!FHE.eint<2>>
@@ -32,7 +32,7 @@ func.func @tensor_extract_1(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2>
   %cst = arith.constant 1 : index

   // The MANP value is 1 as the tensor operand is a function argument
-  // CHECK: %[[ret:.*]] = tensor.extract %[[t:.*]][%[[c1:.*]]] {MANP = 1 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
+  // CHECK: MANP = 1 : ui{{[0-9]+}}
   %0 = tensor.extract %t[%cst] : tensor<4x!FHE.eint<2>>

   return %0 : !FHE.eint<2>
@@ -44,9 +44,9 @@ func.func @tensor_extract_2(%a: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2>
 {
   %c1 = arith.constant 1 : index
   %c3 = arith.constant dense<3> : tensor<4xi3>

-  // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 3 : ui{{[0-9]+}}}
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %0 = "FHELinalg.mul_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
-  // CHECK: %[[ret:.*]] = tensor.extract %[[V0]][%[[c3:.*]]] {MANP = 3 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %2 = tensor.extract %0[%c1] : tensor<4x!FHE.eint<2>>

   return %2 : !FHE.eint<2>
@@ -56,7 +56,7 @@ func.func @tensor_extract_2(%a: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2>
 func.func @tensor_extract_slice_1(%t: tensor<2x10x!FHE.eint<2>>) -> tensor<1x5x!FHE.eint<2>>
 {
-  // CHECK: %[[V0:.*]] = tensor.extract_slice %[[t:.*]][1, 5] [1, 5] [1, 1] {MANP = 1 : ui{{[[0-9]+}}} : tensor<2x10x!FHE.eint<2>> to tensor<1x5x!FHE.eint<2>>
+  // CHECK: MANP = 1 : ui{{[0-9]+}}
   %0 = tensor.extract_slice %t[1, 5] [1, 5] [1, 1] : tensor<2x10x!FHE.eint<2>> to tensor<1x5x!FHE.eint<2>>

   return %0 : tensor<1x5x!FHE.eint<2>>
@@ -68,10 +68,10 @@ func.func @tensor_extract_slice_2(%a: tensor<4x!FHE.eint<2>>) -> tensor<2x!FHE.e
 {
   %c3 = arith.constant dense <3> : tensor<4xi3>

-  // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 3 : ui{{[0-9]+}}}
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %0 = "FHELinalg.mul_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>

-  // CHECK: tensor.extract_slice %[[V0]][2] [2] [1] {MANP = 3 : ui{{[0-9]+}}} : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %2 = tensor.extract_slice %0[2] [2] [1] : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>

   return %2 : tensor<2x!FHE.eint<2>>
@@ -81,7 +81,7 @@ func.func @tensor_extract_slice_2(%a: tensor<4x!FHE.eint<2>>) -> tensor<2x!FHE.e
 func.func @tensor_insert_slice_1(%t0: tensor<2x10x!FHE.eint<2>>, %t1: tensor<2x2x!FHE.eint<2>>) -> tensor<2x10x!FHE.eint<2>>
 {
-  // %[[V0:.*]] = tensor.insert_slice %[[t1:.*]] into %[[t0:.*]][0, 5] [2, 2] [1, 1] {MANP = 1 : ui{{[[0-9]+}}} : tensor<2x2x!FHE.eint<2>> into tensor<2x10x!FHE.eint<2>>
+  // CHECK: MANP = 1 : ui{{[0-9]+}}
   %0 = tensor.insert_slice %t1 into %t0[0, 5] [2, 2] [1, 1] : tensor<2x2x!FHE.eint<2>> into tensor<2x10x!FHE.eint<2>>

   return %0 : tensor<2x10x!FHE.eint<2>>
@@ -90,7 +90,7 @@ func.func @tensor_insert_slice_1(%t0: tensor<2x10x!FHE.eint<2>>, %t1: tensor<2x2
 // -----

 func.func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8x!FHE.eint<6>> {
-  // CHECK: tensor.collapse_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}}
+  // CHECK: MANP = 1 : ui{{[0-9]+}}
   %0 = tensor.collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!FHE.eint<6>> into tensor<2x8x!FHE.eint<6>>
   return %0 : tensor<2x8x!FHE.eint<6>>
 }
@@ -99,9 +99,9 @@ func.func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8
 func.func @tensor_collapse_shape_2(%a: tensor<2x2x4x!FHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!FHE.eint<2>> {
-  // CHECK: "FHELinalg.mul_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 3 : ui{{[0-9]+}}}
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %0 = "FHELinalg.mul_eint_int"(%a, %b) : (tensor<2x2x4x!FHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!FHE.eint<2>>
-  // CHECK-NEXT: tensor.collapse_shape %[[A:.*]] [[X:.*]] {MANP = 3 : ui{{[0-9]+}}}
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %1 = tensor.collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!FHE.eint<2>> into tensor<2x8x!FHE.eint<2>>
   return %1 : tensor<2x8x!FHE.eint<2>>
 }
@@ -109,7 +109,7 @@ func.func @tensor_collapse_shape_2(%a: tensor<2x2x4x!FHE.eint<2>>, %b: tensor<2x
 // -----

 func.func @tensor_expand_shape_1(%a: tensor<2x8x!FHE.eint<6>>) -> tensor<2x2x4x!FHE.eint<6>> {
-  // CHECK: tensor.expand_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}}
+  // CHECK: MANP = 1 : ui{{[0-9]+}}
   %0 = tensor.expand_shape %a [[0],[1,2]] : tensor<2x8x!FHE.eint<6>> into tensor<2x2x4x!FHE.eint<6>>
   return %0 : tensor<2x2x4x!FHE.eint<6>>
 }
@@ -118,9 +118,9 @@ func.func @tensor_expand_shape_2(%a: tensor<2x8x!FHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!FHE.eint<2>> {
-  // CHECK: "FHELinalg.mul_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 3 : ui{{[0-9]+}}}
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %0 = "FHELinalg.mul_eint_int"(%a, %b) : (tensor<2x8x!FHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!FHE.eint<2>>
-  // CHECK-NEXT: tensor.expand_shape %[[A:.*]] [[X:.*]] {MANP = 3 : ui{{[0-9]+}}}
+  // CHECK: MANP = 3 : ui{{[0-9]+}}
   %1 = tensor.expand_shape %0 [[0],[1,2]] : tensor<2x8x!FHE.eint<2>> into tensor<2x2x4x!FHE.eint<2>>
   return %1 : tensor<2x2x4x!FHE.eint<2>>
 }
diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/optimizer_ast.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/optimizer_ast.mlir
index 7dc46a8f5..2ff224951 100644
--- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/optimizer_ast.mlir
+++ b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHE/optimizer_ast.mlir
@@ -1,4 +1,4 @@
-// RUN: concretecompiler --verbose --split-input-file --action=dump-fhe %s 2>&1| FileCheck %s
+// RUN: concretecompiler --verbose --passes canonicalize --passes MANP --passes ConcreteOptimizer --split-input-file --action=dump-fhe-no-linalg %s 2>&1| FileCheck %s

 func.func @main(%arg0: tensor<5x!FHE.eint<5>>) -> !FHE.eint<5> {
   %weights = arith.constant dense<[-1, -1, -1, -1, -1]> : tensor<5xi6>
diff --git a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tensor-ops-to-linalg.mlir b/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tensor-ops-to-linalg.mlir
deleted file mode 100644
index 4d177017c..000000000
--- a/compilers/concrete-compiler/compiler/tests/check_tests/Dialect/FHELinalg/tensor-ops-to-linalg.mlir
+++ /dev/null
@@ -1,28 +0,0 @@
-// RUN: concretecompiler %s --action=dump-fhe-no-linalg 2>&1 | FileCheck %s
-
-// CHECK: module {
-// CHECK-NEXT:   func.func @dot_eint_int(%[[Varg0:.*]]: tensor<2x!FHE.eint<2>>, %[[Varg1:.*]]: tensor<2xi3>) -> !FHE.eint<2> {
-// CHECK-NEXT:     %[[Vc0:.*]] = arith.constant 0 : index
-// CHECK-NEXT:     %[[Vc2:.*]] = arith.constant 2 : index
-// CHECK-NEXT:     %[[Vc1:.*]] = arith.constant 1 : index
-// CHECK-NEXT:     %[[V0:.*]] = "FHE.zero_tensor"() : () -> tensor<1x!FHE.eint<2>>
-// CHECK-NEXT:     %[[V1:.*]] = scf.for %[[Varg2:.*]] = %[[Vc0]] to %[[Vc2]] step %[[Vc1]] iter_args(%[[Varg3:.*]] = %[[V0]]) -> (tensor<1x!FHE.eint<2>>) {
-// CHECK-NEXT:       %[[V3:.*]] = tensor.extract %[[Varg0]]{{\[}}%[[Varg2]]{{\]}} : tensor<2x!FHE.eint<2>>
-// CHECK-NEXT:       %[[V4:.*]] = tensor.extract %[[Varg1]]{{\[}}%[[Varg2]]{{\]}} : tensor<2xi3>
-// CHECK-NEXT:       %[[V5:.*]] = tensor.extract %[[Varg3]]{{\[}}%[[Vc0]]{{\]}} : tensor<1x!FHE.eint<2>>
-// CHECK-NEXT:       %[[V6:.*]] = "FHE.mul_eint_int"(%[[V3]], %[[V4]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
-// CHECK-NEXT:       %[[V7:.*]] = "FHE.add_eint"(%[[V6]], %[[V5]]) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
-// CHECK-NEXT:       %[[V8:.*]] = tensor.insert %[[V7]] into %[[Varg3]]{{\[}}%[[Vc0]]{{\]}} : tensor<1x!FHE.eint<2>>
-// CHECK-NEXT:       scf.yield %[[V8]] : tensor<1x!FHE.eint<2>>
-// CHECK-NEXT:     }
-// CHECK-NEXT:     %[[V2:.*]] = tensor.extract %[[V1]]{{\[}}%[[Vc0]]{{\]}} : tensor<1x!FHE.eint<2>>
-// CHECK-NEXT:     return %[[V2]] : !FHE.eint<2>
-// CHECK-NEXT:   }
-// CHECK-NEXT: }
-func.func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>,
-                        %arg1: tensor<2xi3>) -> !FHE.eint<2>
-{
-  %o = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
-    (tensor<2x!FHE.eint<2>>, tensor<2xi3>) -> !FHE.eint<2>
-  return %o : !FHE.eint<2>
-}