diff --git a/compiler/Makefile b/compiler/Makefile index b5231e0fe..6b8af6d1e 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -240,10 +240,6 @@ release_tarballs: update_python_version: echo "__version__ = \"`git describe --tags --abbrev=0 | grep -e '[0-9].*' -o`\"" > lib/Bindings/Python/version.txt -generate_conv_op: - python -m mlir.dialects.linalg.opdsl.dump_oplib ops.core_named_ops > ops/LinalgNamedStructuredOps.yaml - $(BUILD_DIR)/bin/mlir-linalg-ods-yaml-gen ops/LinalgNamedStructuredOps.yaml --o-impl=ops/LinalgOps.cpp --o-ods-decl=ops/LinalgNamedStructuredOps.yamlgen.td - check_python_format: black --check tests/python/ lib/Bindings/Python/concrete/ @@ -266,7 +262,6 @@ python_lint: package_py310 \ release_tarballs \ update_python_version \ - generate_conv_op \ python_lint \ python_format \ check_python_format \ diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h index 35d1255f3..fee8c6094 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h @@ -97,10 +97,6 @@ public: } // namespace mlir #define GET_OP_CLASSES -// TODO: remove this when removing the custom linalg op for Conv -// the generated code was calling functions from the mlir::linalg namespace -using namespace mlir::linalg; -// END TODO #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc" #endif diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index 04875e46a..447c77ca0 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -3,9 +3,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Dialect/Linalg/IR/LinalgBase.td" -include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" -include "mlir/Interfaces/InferTypeOpInterface.td" include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td" include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.td" @@ -946,151 +943,6 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> { let hasVerifier = 1; } -class LinalgStructuredBase_Op props> - : Op, - DeclareOpInterfaceMethods, - LinalgStructuredInterface, - ReifyRankedShapedTypeOpInterface], props)> { - code structuredOpsBaseDecls = [{ - // Return whether the op accesses the iteration indices. - bool hasIndexSemantics() { - return !this->getBody()->getOps().empty(); - } - - LogicalResult reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes) { - return cast(getOperation()).reifyResultShapes(b, - reifiedReturnShapes); - } - }]; -} - -def FhelinalgConv2DNchwFchwOp : LinalgStructuredBase_Op<"fhelinalg_conv_2d_nchw_fchw", !listconcat([AttrSizedOperandSegments], - /*extraInterfaces=*/[LinalgConvolutionOpInterface])> { - - let cppNamespace = "mlir::concretelang::FHELinalg"; - let summary = [{ Performs 2-D convolution. }]; - let description = [{ - Layout: - * Input: NCHW. - * Kernel: FCHW. - -Numeric casting is performed on the operands to the inner multiply, promoting -them to the same data type as the accumulator/output. - }]; - - let arguments = (ins - Variadic:$inputs, - Variadic:$outputs, -RankedI64ElementsAttr<[2]>:$strides, -RankedI64ElementsAttr<[2]>:$dilations - ); - let results = (outs Variadic:$result_tensors); - let regions = (region AnyRegion:$region); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder< - (ins "ValueRange":$inputs, "ValueRange":$outputs, - CArg<"ArrayRef", "{}">:$attributes), - [{ - $_state.addOperands(inputs); - $_state.addOperands(outputs); - SmallVector resultTensorTypes; - copy_if(outputs.getTypes(), - std::back_inserter(resultTensorTypes), - [](Type type) { return type.isa(); }); - $_state.addTypes(resultTensorTypes); - $_state.addAttribute( - "operand_segment_sizes", - $_builder.getI32VectorAttr({ - static_cast(inputs.size()), - static_cast(outputs.size())})); - $_state.addAttributes(attributes); - createAndFillStructuredOpRegion( - $_builder, - $_state, - TypeRange(inputs), - TypeRange(outputs)); - }]>, - OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs, - CArg<"ArrayRef", "{}">:$attributes), - [{ - $_state.addOperands(inputs); - $_state.addOperands(outputs); - $_state.addTypes(resultTensorTypes); - $_state.addAttributes(attributes); - $_state.addAttribute( - "operand_segment_sizes", - $_builder.getI32VectorAttr({ - static_cast(inputs.size()), - static_cast(outputs.size())})); - createAndFillStructuredOpRegion( - $_builder, - $_state, - TypeRange(inputs), - TypeRange(outputs)); - }]>, - OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), - [{ - $_state.addOperands(operands); - $_state.addAttributes(attributes); - $_state.addTypes(resultTensorTypes); - (void)$_state.addRegion(); - }]> - - , OpBuilder< - (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, - "ValueRange":$outputs, "Attribute":$strides, "Attribute":$dilations, - CArg<"ArrayRef", "{}">:$attributes), - [{ - $_state.addOperands(inputs); - $_state.addOperands(outputs); - $_state.addTypes(resultTensorTypes); - $_state.addAttribute( - "operand_segment_sizes", - $_builder.getI32VectorAttr({ - static_cast(inputs.size()), - static_cast(outputs.size())})); - createAndFillStructuredOpRegion( - $_builder, - $_state, - TypeRange(inputs), - TypeRange(outputs)); - $_state.addAttribute("strides", strides); -$_state.addAttribute("dilations", dilations); - $_state.addAttributes(attributes); - }]> - - ]; - let hasCustomAssemblyFormat = 1; - let hasFolder = 1; - - let extraClassDeclaration = structuredOpsBaseDecls # [{ - // Auto-generated. - ArrayAttr iterator_types(); - ArrayAttr indexing_maps(); - static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, llvm::ArrayRef); - static std::function)> - getRegionBuilder() { - return regionBuilder; - } - - // Generic methods. - static unsigned getNumRegionArgs(); - std::string getLibraryCallName(); - - bool hasDynamicIndexingMaps(); - LogicalResult verifyIndexingMapRequiredAttributes(); - - }]; -} - def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> { let summary = "Returns a tensor that contains the transposition of the input tensor."; diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 39b9f72b0..ee82d3216 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -6,13 +6,9 @@ #include #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Parser/Parser.h" -#include "llvm/Support/FormatVariadic.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" @@ -1019,688 +1015,6 @@ mlir::LogicalResult FromElementOp::verify() { return mlir::success(); } -//===----------------------------------------------------------------------===// -// Implementation of FhelinalgConv2DNchwFchwOp -// This is a generated functions from `make generate_conv_op`, and some helpers -// from LinalgOps.cpp -//===----------------------------------------------------------------------===// -using namespace mlir; -using namespace mlir::linalg; - -ArrayAttr FhelinalgConv2DNchwFchwOp::iterator_types() { - return Builder(getContext()) - .getStrArrayAttr(SmallVector{ - getParallelIteratorTypeName(), getParallelIteratorTypeName(), - getParallelIteratorTypeName(), getParallelIteratorTypeName(), - getReductionIteratorTypeName(), getReductionIteratorTypeName(), - getReductionIteratorTypeName()}); -} - -static SmallVector -getSymbolBindings(FhelinalgConv2DNchwFchwOp self) { - MLIRContext *context = self.getContext(); - SmallVector exprs; - exprs.push_back(getAffineSymbolExpr(0, context)); - exprs.push_back(getAffineSymbolExpr(1, context)); - exprs.push_back(getAffineSymbolExpr(2, context)); - - int64_t cst3 = self.strides().getValues()[{0}]; - exprs.push_back(getAffineConstantExpr(cst3, context)); - - exprs.push_back(getAffineSymbolExpr(4, context)); - - int64_t cst5 = self.dilations().getValues()[{0}]; - exprs.push_back(getAffineConstantExpr(cst5, context)); - - exprs.push_back(getAffineSymbolExpr(6, context)); - - int64_t cst7 = self.strides().getValues()[{1}]; - exprs.push_back(getAffineConstantExpr(cst7, context)); - - exprs.push_back(getAffineSymbolExpr(8, context)); - - int64_t cst9 = self.dilations().getValues()[{1}]; - exprs.push_back(getAffineConstantExpr(cst9, context)); - - exprs.push_back(getAffineSymbolExpr(10, context)); - return exprs; -} - -ArrayAttr FhelinalgConv2DNchwFchwOp::indexing_maps() { - static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; - ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); - if (cached) - return cached; - - MLIRContext *context = getContext(); - auto symbolBindings = getSymbolBindings(*this); - SmallVector maps; - maps.push_back( - mlir::parseAttribute( - "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, " - "s7, s8, s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>", - context) - .cast() - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); - maps.push_back(mlir::parseAttribute( - "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, " - "s4, s5, s6, s7, s8, s9, s10] -> (d1, d4, d5, d6)>", - context) - .cast() - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); - maps.push_back(mlir::parseAttribute( - "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, " - "s4, s5, s6, s7, s8, s9, s10] -> (d0, d1, d2, d3)>", - context) - .cast() - .getValue()); - maps.back() = simplifyAffineMap( - maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); - cached = Builder(context).getAffineMapArrayAttr(maps); - getOperation()->setAttr(memoizeAttr, cached); - return cached; -} - -unsigned FhelinalgConv2DNchwFchwOp::getNumRegionArgs() { return 3; } - -std::string FhelinalgConv2DNchwFchwOp::getLibraryCallName() { - return generateLibraryCallName(getOperation()); -} - -bool FhelinalgConv2DNchwFchwOp::hasDynamicIndexingMaps() { return true; } -LogicalResult FhelinalgConv2DNchwFchwOp::verifyIndexingMapRequiredAttributes() { - Operation *op = getOperation(); - - if (auto attr = op->getAttrOfType("strides")) { - if (!attr.getType().getElementType().isInteger(64)) - return op->emitError("incorrect element type for indexing map required " - "attribute 'strides'"); - if (attr.getType().getShape() != ArrayRef{2}) - return op->emitError( - "incorrect shape for indexing map required attribute 'strides'"); - } else { - return op->emitError("missing indexing map required attribute 'strides'"); - } - - if (auto attr = op->getAttrOfType("dilations")) { - if (!attr.getType().getElementType().isInteger(64)) - return op->emitError("incorrect element type for indexing map required " - "attribute 'dilations'"); - if (attr.getType().getShape() != ArrayRef{2}) - return op->emitError( - "incorrect shape for indexing map required attribute 'dilations'"); - } else { - return op->emitError("missing indexing map required attribute 'dilations'"); - } - - return success(); -} - -// Copied from LinalgOps.cpp; license is: Apache License v2.0 with -// LLVM Exceptions -using RegionBuilderFn = llvm::function_ref)>; - -// Copied from LinalgOps.cpp; license is: Apache License v2.0 with -// LLVM Exceptions -static void printNamedStructuredOpResults(OpAsmPrinter &p, - TypeRange resultTypes) { - if (resultTypes.empty()) - return; - p.printOptionalArrowTypeList(resultTypes); -} - -// Copied from LinalgOps.cpp; license is: Apache License v2.0 with -// LLVM Exceptions -static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, - ValueRange outputs) { - if (!inputs.empty()) - p << " ins(" << inputs << " : " << inputs.getTypes() << ")"; - if (!outputs.empty()) - p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; -} - -// Copied from LinalgOps.cpp; license is: Apache License v2.0 with -// LLVM Exceptions -static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, - ValueRange inputs, ValueRange outputs) { - p.printOptionalAttrDict( - op->getAttrs(), - /*elidedAttrs=*/{"operand_segment_sizes", - // See generated code in mlir-linalg-yaml-gen.cpp - "linalg.memoized_indexing_maps"}); - - // Printing is shared with generic ops, except for the region and - // attributes. - printCommonStructuredOpParts(p, inputs, outputs); - - // Results printing. - printNamedStructuredOpResults(p, op->getResultTypes()); - - // Region is elided. -} - -void FhelinalgConv2DNchwFchwOp::print(mlir::OpAsmPrinter &p) { - printNamedStructuredOp(p, this->getOperation(), - this->getOperation()->getOperands(), - this->getOperation()->getResults()); -} - -// Copied from LinalgOps.cpp; license is: Apache License v2.0 with -// LLVM Exceptions -static ParseResult -parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, - SmallVectorImpl &inputTypes, - SmallVectorImpl &outputTypes) { - SMLoc inputsOperandsLoc, outputsOperandsLoc; - SmallVector inputsOperands, - outputsOperands; - - parser.parseOptionalAttrDict(result.attributes); - - if (succeeded(parser.parseOptionalKeyword("ins"))) { - if (parser.parseLParen()) - return failure(); - - inputsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseOperandList(inputsOperands) || - parser.parseColonTypeList(inputTypes) || parser.parseRParen()) - return failure(); - } - - if (succeeded(parser.parseOptionalKeyword("outs"))) { - outputsOperandsLoc = parser.getCurrentLocation(); - if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || - parser.parseColonTypeList(outputTypes) || parser.parseRParen()) - return failure(); - } - - if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, - result.operands) || - parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, - result.operands)) - return failure(); - - result.addAttribute("operand_segment_sizes", - parser.getBuilder().getI32VectorAttr( - {static_cast(inputsOperands.size()), - static_cast(outputsOperands.size())})); - return success(); -} - -// Copied from LinalgOps.cpp; license is: Apache License v2.0 with -// LLVM Exceptions -static ParseResult -parseNamedStructuredOpResults(OpAsmParser &parser, - SmallVectorImpl &resultTypes) { - if (parser.parseOptionalArrowTypeList(resultTypes)) - return failure(); - return success(); -} - -// Copied from LinalgOps.cpp; license is: Apache License v2.0 with -// LLVM Exceptions -static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes, - ArrayRef attrs, - RegionBuilderFn regionBuilder) { - assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); - - // TODO: atm all operands go through getElementTypeOrSelf, - // reconsider when we have evidence we need to. - SmallVector argTypes; - SmallVector argLocs; - for (auto containers : {inputTypes, outputTypes}) { - for (auto t : containers) { - argTypes.push_back(getElementTypeOrSelf(t)); - - // TODO: Pass in a proper location here. - argLocs.push_back(opBuilder.getUnknownLoc()); - } - } - - // RAII. - OpBuilder::InsertionGuard guard(opBuilder); - Block *body = - opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); - - opBuilder.setInsertionPointToStart(body); - ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); - regionBuilder(b, *body, attrs); - - // indexing_maps is an auto-generated method. - - // iterator_types is an auto-generated method. -} - -// Copied from LinalgOps.cpp; license is: Apache License v2.0 with -// LLVM Exceptions -static ParseResult parseNamedStructuredOpRegion( - OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, - TypeRange inputTypes, TypeRange outputTypes, ArrayRef attrs, - RegionBuilderFn regionBuilder) { - if (numRegionArgs != inputTypes.size() + outputTypes.size()) { - return parser.emitError( - parser.getCurrentLocation(), - llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " - "region expects {0} args, got {1}", - numRegionArgs, inputTypes.size() + outputTypes.size())); - } - - OpBuilder opBuilder(parser.getContext()); - fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs, - regionBuilder); - return success(); -} - -// Copied from LinalgOps.cpp; license is: Apache License v2.0 with -// LLVM Exceptions -static ParseResult parseNamedStructuredOp(OpAsmParser &parser, - OperationState &result, - unsigned numRegionArgs, - RegionBuilderFn regionBuilder) { - // TODO: Enable when ods-gen supports captures. - SmallVector inputTypes, outputTypes; - if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) - return failure(); - - // TODO: consider merging results parsing into region parsing. - // Need to wait for declarative assembly resolution to decide. - SmallVector outputTensorsTypes; - if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) - return failure(); - result.addTypes(outputTensorsTypes); - - std::unique_ptr region = std::make_unique(); - if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes, - outputTypes, result.attributes.getAttrs(), - regionBuilder)) - return failure(); - result.addRegion(std::move(region)); - - return success(); -} - -mlir::ParseResult -FhelinalgConv2DNchwFchwOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - return parseNamedStructuredOp(parser, result, getNumRegionArgs(), - getRegionBuilder()); -} - -/// Some helpers were copied from LinalgOps.cpp - -/// Generic entry point to create the block for the region of a LinalgOp. -/// This is used by both named structured ops created by ods-gen and by manually -/// defined C++ ops. -/// This is used by both builders and parsers. -/// This function creates the block in the region with arguments corresponding -/// to the elemental types of `inputTypes` and `outputTypes`. The latter are -/// asserted to be of ShapedType. -template -static void fillStructuredOpRegion( - OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, - TypeRange outputTypes, - std::function errorHandler = nullptr); - -/// Generic entry point to create both the region and the block of a LinalgOp. -template -static void -createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, - TypeRange inputTypes, TypeRange outputTypes); - -/// Common parsing and printing used for both named structured ops created by -/// ods-gen and by manually defined C++ ops. Does not handle regions. -static ParseResult -parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, - SmallVectorImpl &inputTypes, - SmallVectorImpl &outputTypes); -template -static void printCommonStructuredOpParts(OpAsmPrinter &p, - NamedStructuredOpType op); - -/// Specific parsing and printing for named structured ops created by ods-gen. -template -static ParseResult -parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes); - -static ParseResult -parseNamedStructuredOpResults(OpAsmParser &parser, - SmallVectorImpl &resultTypes); - -template -static ParseResult parseNamedStructuredOp(OpAsmParser &parser, - OperationState &result); - -static void printNamedStructuredOpResults(OpAsmPrinter &p, - TypeRange resultTypes); - -template -static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); - -class RegionBuilderHelper { -public: - RegionBuilderHelper(MLIRContext *context, Block &block) - : context(context), block(block) {} - - // Generates operations to cast the given operand to a specified type. - // If the cast cannot be performed, a warning will be issued and the - // operand returned as-is (which will presumably yield a verification - // issue downstream). - Value cast(Type toType, Value operand, bool isUnsignedCast) { - OpBuilder builder = getBuilder(); - auto loc = operand.getLoc(); - - if (operand.getType() == toType) - return operand; - if (auto toIntType = toType.dyn_cast()) { - // If operand is floating point, cast directly to the int type. - if (operand.getType().isa()) { - if (isUnsignedCast) - return builder.create(loc, toType, operand); - return builder.create(loc, toType, operand); - } - // Cast index operands directly to the int type. - if (operand.getType().isIndex()) - return builder.create(loc, toType, operand); - if (auto fromIntType = operand.getType().dyn_cast()) { - // Either extend or truncate. - if (toIntType.getWidth() > fromIntType.getWidth()) { - if (isUnsignedCast) - return builder.create(loc, toType, operand); - return builder.create(loc, toType, operand); - } - if (toIntType.getWidth() < fromIntType.getWidth()) - return builder.create(loc, toType, operand); - } - } else if (auto toFloatType = toType.dyn_cast()) { - // If operand is integer, cast directly to the float type. - // Note that it is unclear how to cast from BF16<->FP16. - if (operand.getType().isa()) { - if (isUnsignedCast) - return builder.create(loc, toFloatType, operand); - return builder.create(loc, toFloatType, operand); - } - if (auto fromFloatType = operand.getType().dyn_cast()) { - if (toFloatType.getWidth() > fromFloatType.getWidth()) - return builder.create(loc, toFloatType, operand); - if (toFloatType.getWidth() < fromFloatType.getWidth()) - return builder.create(loc, toFloatType, operand); - } - } - - emitWarning(operand.getLoc()) << "could not cast operand of type " - << operand.getType() << " to " << toType; - return operand; - } - - Value applyfn__add(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs) && - rhs.getType().isa()) { - return builder.create(lhs.getLoc(), - rhs, lhs); - } - if (lhs.getType().isa() && - isInteger(rhs)) { - return builder.create(lhs.getLoc(), - lhs, rhs); - } - if (lhs.getType().isa() && - rhs.getType().isa()) { - return builder.create(lhs.getLoc(), - lhs, rhs); - } - llvm_unreachable("unsupported non numeric type"); - } - - Value applyfn__exp(Value x) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(x)) - return builder.create(x.getLoc(), x); - llvm_unreachable("unsupported non numeric type"); - } - - Value applyfn__log(Value x) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(x)) - return builder.create(x.getLoc(), x); - llvm_unreachable("unsupported non numeric type"); - } - - Value applyfn__sub(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - Value applyfn__mul(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (lhs.getType().isa() && - isInteger(rhs)) { - return builder.create(lhs.getLoc(), - lhs, rhs); - } - llvm_unreachable("unsupported non numeric type"); - } - - Value applyfn__max(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - Value applyfn__max_unsigned(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - Value applyfn__min(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - Value applyfn__min_unsigned(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - void yieldOutputs(ValueRange values) { - assert(!values.empty() && "linalg ops must yield outputs"); - if (values.empty()) - return; - Value first = values.front(); - OpBuilder builder = getBuilder(); - builder.create(first.getLoc(), values); - } - - Value constant(std::string value) { - OpBuilder builder = getBuilder(); - Location loc = builder.getUnknownLoc(); - Attribute valueAttr = parseAttribute(value, builder.getContext()); - return builder.create(loc, valueAttr.getType(), - valueAttr); - } - - Value index(int64_t dim) { - OpBuilder builder = getBuilder(); - return builder.create(builder.getUnknownLoc(), dim); - } - - Type getIntegerType(unsigned width) { - return IntegerType::get(context, width); - } - - Type getFloat32Type() { return Float32Type::get(context); } - - Type getFloat64Type() { return Float64Type::get(context); } - -private: - MLIRContext *context; - Block █ - - bool isFloatingPoint(Value value) { return value.getType().isa(); } - bool isInteger(Value value) { return value.getType().isa(); } - - OpBuilder getBuilder() { - OpBuilder builder(context); - builder.setInsertionPointToEnd(&block); - return builder; - } -}; - -static LogicalResult foldMemRefCast(Operation *op) { - bool folded = false; - for (OpOperand &operand : op->getOpOperands()) { - auto castOp = operand.get().getDefiningOp(); - if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { - operand.set(castOp.getOperand()); - folded = true; - } - } - return success(folded); -} - -/// Generic entry point to create the block for the region of a LinalgOp. -/// This is used by both named structured ops created by ods-gen and by manually -/// defined C++ ops. -/// This is used by both builders and parsers. -/// This function creates the block in the region with arguments corresponding -/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted -/// to be ShapedType. -template -static void -fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, - TypeRange inputTypes, TypeRange outputTypes, - std::function errorHandler) { - assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); - - // TODO: atm all operands go through getElementTypeOrSelf, - // reconsider when we have evidence we need to. - SmallVector argTypes; - SmallVector argLocs; - for (auto containers : {inputTypes, outputTypes}) - for (auto t : containers) { - argTypes.push_back(getElementTypeOrSelf(t)); - argLocs.push_back(opBuilder.getUnknownLoc()); - } - - // RAII. - OpBuilder::InsertionGuard guard(opBuilder); - Block *body = - opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); - unsigned actual = body->getNumArguments(); - unsigned expected = NamedStructuredOpType::getNumRegionArgs(); - if (expected != actual) { - if (errorHandler) - errorHandler(expected, actual); - return; - } - - opBuilder.setInsertionPointToStart(body); - ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); - NamedStructuredOpType::regionBuilder(b, *body, {}); - - // indexing_maps is an auto-generated method. - - // iterator_types is an auto-generated method. -} - -static void getGenericEffectsImpl( - SmallVectorImpl> - &effects, - ValueRange results, ValueRange inputBuffers, ValueRange outputs) { - for (Value value : results) { - effects.emplace_back(MemoryEffects::Allocate::get(), value, - SideEffects::DefaultResource::get()); - } - for (Value value : inputBuffers) { - effects.emplace_back(MemoryEffects::Read::get(), value, - SideEffects::DefaultResource::get()); - } - for (Value value : outputs) { - effects.emplace_back(MemoryEffects::Read::get(), value, - SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, - SideEffects::DefaultResource::get()); - } -} - -/// Generic entry point to create both the region and the block of a LinalgOp. -template -void createAndFillStructuredOpRegion(OpBuilder &opBuilder, - OperationState &result, - TypeRange inputTypes, - TypeRange outputTypes) { - Region ®ion = *result.addRegion(); - fillStructuredOpRegion( - opBuilder, region, inputTypes, outputTypes, - [&](unsigned expected, unsigned actual) { - assert(expected != actual && "incorrect number of arguments"); - }); -} - -/// END OF COPY FROM LinalgOps.cpp - -void FhelinalgConv2DNchwFchwOp::regionBuilder( - ImplicitLocOpBuilder &b, Block &block, - llvm::ArrayRef) { - assert(3 > 0 && block.getNumArguments() == 3 && - "FhelinalgConv2DNchwFchwOp regionBuilder expects 3 (>=0) args"); - RegionBuilderHelper helper(block.getArgument(0).getContext(), block); - SmallVector yields; - Value value1 = - helper.cast(block.getArgument(0).getType(), block.getArgument(0), false); - Value value2 = - helper.cast(block.getArgument(1).getType(), block.getArgument(1), false); - Value value3 = helper.applyfn__mul(value1, value2); - Value value4 = helper.applyfn__add(block.getArgument(2), value3); - yields.push_back(value4); - helper.yieldOutputs(yields); -} - -LogicalResult FhelinalgConv2DNchwFchwOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} -void FhelinalgConv2DNchwFchwOp::getEffects( - SmallVectorImpl> - &effects) { - SmallVector inputBuffers = getInputBufferOperands(); - SmallVector outputBuffers = getOutputBufferOperands(); - getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, - outputBuffers); -} - /// Verify the transpose shapes mlir::LogicalResult TransposeOp::verify() { mlir::Type tensorTy = ((mlir::Type)this->tensor().getType()); diff --git a/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp index 713ee68d4..bc0787e9a 100644 --- a/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp +++ b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp @@ -50,7 +50,7 @@ static bool isCandidateForTask(Operation *op) { FHELinalg::ApplyLookupTableEintOp, FHELinalg::ApplyMultiLookupTableEintOp, FHELinalg::ApplyMappedLookupTableEintOp, FHELinalg::Dot, FHELinalg::MatMulEintIntOp, FHELinalg::MatMulIntEintOp, FHELinalg::SumOp, - FHELinalg::ConcatOp, FHELinalg::FhelinalgConv2DNchwFchwOp>(op); + FHELinalg::ConcatOp>(op); } /// Identify operations that are beneficial to sink into tasks. These diff --git a/compiler/ops/core_named_ops.py b/compiler/ops/core_named_ops.py deleted file mode 100644 index b577c20f7..000000000 --- a/compiler/ops/core_named_ops.py +++ /dev/null @@ -1,29 +0,0 @@ -from mlir.dialects.linalg.opdsl.lang import * - -T1 = TV.T1 -T2 = TV.T2 - -Batch = S.Batch - - -@linalg_structured_op -def fhelinalg_conv_2d_nchw_fchw( - I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.F, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): - """Performs 2-D convolution. - - Layout: - * Input: NCHW. - * Kernel: FCHW. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.f, D.oh, D.ow] += cast( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW - ]) * cast(U, K[D.f, D.c, D.kh, D.kw])