// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #ifndef CONCRETELANG_SUPPORT_LINALG_EXTRAS_H_ #define CONCRETELANG_SUPPORT_LINALG_EXTRAS_H_ #include #include #include #include #include #include namespace mlir { namespace concretelang { namespace linalgextras { using namespace mlir; using namespace mlir::linalg; static SmallVector makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, ArrayRef vals) { if (map.isEmpty()) return {}; assert(map.getNumInputs() == vals.size()); SmallVector res; res.reserve(map.getNumResults()); auto dims = map.getNumDims(); for (auto e : map.getResults()) { auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e); SmallVector operands(vals.begin(), vals.end()); canonicalizeMapAndOperands(&exprMap, &operands); res.push_back(b.create(loc, exprMap, operands)); } return res; } template static std::vector inlineRegionAndEmitStore( OpBuilder &b, Location loc, OpType op, ArrayRef indexedValues, ArrayRef> indexing, ArrayRef outputBuffers) { auto &block = op->getRegion(0).front(); BlockAndValueMapping map; map.map(block.getArguments(), indexedValues); for (auto &op : block.without_terminator()) { auto *newOp = b.clone(op, map); map.map(op.getResults(), newOp->getResults()); } Operation *terminator = block.getTerminator(); std::vector retVals; for (OpOperand &operand : terminator->getOpOperands()) { Value toStore = map.lookupOrDefault(operand.get()); Value newTens = b.create( loc, toStore, outputBuffers[operand.getOperandNumber()], indexing[operand.getOperandNumber()]); retVals.push_back(newTens); } return retVals; } /// Replace the index operations in the body of the loop nest by the matching /// induction variables. static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef loopOps) { // Extract the induction variables of the loop nest from outer to inner. SmallVector allIvs; for (Operation *loopOp : loopOps) { llvm::TypeSwitch(loopOp) .Case([&](scf::ParallelOp parallelOp) { allIvs.append(parallelOp.getInductionVars().begin(), parallelOp.getInductionVars().end()); }) .Case([&](scf::ForOp forOp) { allIvs.push_back(forOp.getInductionVar()); }) .Case([&](AffineForOp affineForOp) { allIvs.push_back(affineForOp.getInductionVar()); }) .Default([&](Operation *op) { assert(false && "unexpected op"); }); } assert(linalgOp.getNumLoops() == allIvs.size() && "expected the number of loops and induction variables to match"); // Replace the index operations in the body of the innermost loop op. if (!loopOps.empty()) { LoopLikeOpInterface loopOp = loopOps.back(); for (IndexOp indexOp : llvm::make_early_inc_range(loopOp.getLoopBody().getOps())) rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]); } } template static std::vector emitScalarImplementation(OpBuilder &b, Location loc, ArrayRef allIvs, LinalgOp linalgOp, ValueRange operandValuesToUse) { assert(linalgOp.hasTensorSemantics() && "expected linalg op with buffer semantics"); SmallVector indexedValues; indexedValues.reserve(linalgOp.getNumInputsAndOutputs()); auto allIvsPlusDims = SmallVector(allIvs.begin(), allIvs.end()); // TODO: Avoid the loads if the corresponding argument of the // region has no uses. // 1.a. Emit load from input operand or for scalars access the operand itself. for (OpOperand *inputOperand : linalgOp.getInputOperands()) { if (linalgOp.isScalar(inputOperand)) { indexedValues.push_back(inputOperand->get()); continue; } auto indexing = makeCanonicalAffineApplies( b, loc, linalgOp.getTiedIndexingMap(inputOperand), allIvsPlusDims); indexedValues.push_back( b.create(loc, inputOperand->get(), indexing)); } // 1.b. Emit load from output views. for (OpOperand *outputOperand : linalgOp.getOutputOperands()) { SmallVector indexing = makeCanonicalAffineApplies( b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims); indexedValues.push_back( b.create(loc, outputOperand->get(), indexing)); } // TODO: When a region inliner exists, use it. // 2. Inline region, currently only works for a single basic block. // 3. Emit store. SmallVector, 8> indexing; SmallVector outputBuffers; for (OpOperand *outputOperand : linalgOp.getOutputTensorOperands()) { indexing.push_back(makeCanonicalAffineApplies( b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims)); outputBuffers.push_back(operandValuesToUse.back()); } return inlineRegionAndEmitStore( b, loc, linalgOp, indexedValues, indexing, outputBuffers); } template static FailureOr linalgTensorOpToLoopsImpl(PatternRewriter &rewriter, LinalgOp linalgOp, bool parallelizeLoops) { // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). assert(linalgOp.hasTensorSemantics() && "expected linalg op with value semantics"); auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc()); auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); SmallVector allIvs; GenerateLoopNest::doit( rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange operandValuesToUse) -> scf::ValueVector { // assert(operandValuesToUse == linalgOp->getOperands() && // "expect operands are captured and not passed by loop // argument"); allIvs.append(ivs.begin(), ivs.end()); return emitScalarImplementation( b, loc, allIvs, linalgOp, operandValuesToUse); // return scf::ValueVector{}; }); // Number of loop ops might be different from the number of ivs since some // loops like affine.parallel and scf.parallel have multiple ivs. SetVector loopSet; for (Value iv : allIvs) { if (!iv) return failure(); // The induction variable is a block argument of the entry block of the // loop operation. BlockArgument ivVal = iv.dyn_cast(); if (!ivVal) return failure(); loopSet.insert(ivVal.getOwner()->getParentOp()); } LinalgLoops loops(loopSet.begin(), loopSet.end()); // Just mark loop with a parallel attributes if (parallelizeLoops) { for (auto loop : llvm::enumerate(loops)) { loop.value()->setAttr("parallel", rewriter.getBoolAttr(isParallelIterator( iteratorTypes[loop.index()]))); } } // Replace all index operations in the loop body. replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops); return loops; } } // namespace linalgextras } // namespace concretelang } // namespace mlir #endif