From 78596f899fa2b38150f60bedc176274147427b4f Mon Sep 17 00:00:00 2001 From: youben11 Date: Wed, 9 Feb 2022 14:41:33 +0100 Subject: [PATCH] feat: add generated linalg conv operation This has been generated using linalg tools, then put in their appropriate locations. This is intended as a workaround since linalg doesn't support tensors of custom types yet. Any conversion using this added operation should be able to use the default operation from linalg when it starts supporting tensor of custom types. --- compiler/Makefile | 8 +- .../Dialect/FHELinalg/IR/FHELinalgOps.h | 5 + .../Dialect/FHELinalg/IR/FHELinalgOps.td | 152 +++++ .../lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp | 619 ++++++++++++++++++ compiler/ops/core_named_ops.py | 29 + 5 files changed, 812 insertions(+), 1 deletion(-) create mode 100644 compiler/ops/core_named_ops.py diff --git a/compiler/Makefile b/compiler/Makefile index ed9e99028..5be572c83 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -203,6 +203,11 @@ release_tarballs: update_python_version: echo "__version__ = \"`git describe --tags --abbrev=0 | grep -e '[0-9].*' -o`\"" > lib/Bindings/Python/version.txt +generate_conv_op: + python -m mlir.dialects.linalg.opdsl.dump_oplib ops.core_named_ops > ops/LinalgNamedStructuredOps.yaml + $(BUILD_DIR)/bin/mlir-linalg-ods-yaml-gen ops/LinalgNamedStructuredOps.yaml --o-impl=ops/LinalgOps.cpp --o-ods-decl=ops/LinalgNamedStructuredOps.yamlgen.td + + .PHONY: build-initialized \ build-end-to-end-jit \ concretecompiler \ @@ -226,4 +231,5 @@ update_python_version: install \ uninstall\ install_runtime_lib \ - uninstall_runtime_lib + uninstall_runtime_lib \ + generate_conv_op diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h index 0ffa52b7a..3f6bbfd46 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h @@ -11,6 +11,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include #include #include @@ -96,6 +97,10 @@ public: } // namespace mlir #define GET_OP_CLASSES +// TODO: remove this when removing the custom linalg op for Conv +// the generated code was calling functions from the mlir::linalg namespace +using namespace mlir::linalg; +// END TODO #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc" #endif diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index e82d86cbe..670c0a87c 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -3,6 +3,9 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td" include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.td" @@ -628,4 +631,153 @@ def ConcatOp : FHELinalg_Op<"concat"> { }]; } +class LinalgStructuredBase_Op props> + : Op, + DeclareOpInterfaceMethods, + LinalgStructuredInterface, + ReifyRankedShapedTypeOpInterface], props)> { + code structuredOpsBaseDecls = [{ + // Return whether the op accesses the iteration indices. 
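+    // (A hedged note: the check below mirrors the upstream
+    // LinalgStructuredBase_Op helper this class was copied from, which asks
+    // whether the body contains any linalg.index operations.)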
+ bool hasIndexSemantics() { + return !this->getBody()->getOps().empty(); + } + + LogicalResult reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + return cast(getOperation()).reifyResultShapes(b, + reifiedReturnShapes); + } + }]; +} + +def FhelinalgConv2DNchwFchwOp : LinalgStructuredBase_Op<"fhelinalg_conv_2d_nchw_fchw", !listconcat([AttrSizedOperandSegments], + /*extraInterfaces=*/[LinalgConvolutionOpInterface])> { + + let cppNamespace = "mlir::concretelang::FHELinalg"; + let summary = [{ Performs 2-D convolution. }]; + let description = [{ + Layout: + * Input: NCHW. + * Kernel: FCHW. + +Numeric casting is performed on the operands to the inner multiply, promoting +them to the same data type as the accumulator/output. + }]; + + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs, +RankedI64ElementsAttr<[2]>:$strides, +RankedI64ElementsAttr<[2]>:$dilations + ); + let results = (outs Variadic:$result_tensors); + let regions = (region AnyRegion:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder< + (ins "ValueRange":$inputs, "ValueRange":$outputs, + CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addOperands(inputs); + $_state.addOperands(outputs); + SmallVector resultTensorTypes; + copy_if(outputs.getTypes(), + std::back_inserter(resultTensorTypes), + [](Type type) { return type.isa(); }); + $_state.addTypes(resultTensorTypes); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({ + static_cast(inputs.size()), + static_cast(outputs.size())})); + $_state.addAttributes(attributes); + createAndFillStructuredOpRegion( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)); + }]>, + OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs, + CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addOperands(inputs); + $_state.addOperands(outputs); + $_state.addTypes(resultTensorTypes); + $_state.addAttributes(attributes); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({ + static_cast(inputs.size()), + static_cast(outputs.size())})); + createAndFillStructuredOpRegion( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)); + }]>, + OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, + CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addOperands(operands); + $_state.addAttributes(attributes); + $_state.addTypes(resultTensorTypes); + (void)$_state.addRegion(); + }]> + + , OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs, "Attribute":$strides, "Attribute":$dilations, + CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addOperands(inputs); + $_state.addOperands(outputs); + $_state.addTypes(resultTensorTypes); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({ + static_cast(inputs.size()), + static_cast(outputs.size())})); + createAndFillStructuredOpRegion( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)); + $_state.addAttribute("strides", strides); +$_state.addAttribute("dilations", dilations); + $_state.addAttributes(attributes); + }]> + + ]; + let printer = [{ return mlir::concretelang::FHELinalg::printNamedStructuredOp(p, *this); }]; + let parser = [{ + return mlir::concretelang::FHELinalg::parseNamedStructuredOp(parser, result); + }]; + let hasFolder = 1; + + let extraClassDeclaration = structuredOpsBaseDecls # [{ + // Auto-generated. 
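+    // The declarations below are the hooks mlir-linalg-ods-yaml-gen expects:
+    // iterator_types() reports the 4 parallel and 3 reduction loops,
+    // indexing_maps() folds the strides/dilations attributes into the input
+    // map, and regionBuilder() emits the multiply-accumulate body; all are
+    // implemented in FHELinalgOps.cpp.
+    //
+    // A lowering pattern would typically build this op with the builder that
+    // takes the strides/dilations attributes, roughly as follows (a sketch;
+    // the value and attribute names are illustrative only):
+    //   rewriter.create<FhelinalgConv2DNchwFchwOp>(
+    //       loc, TypeRange{resultType}, ValueRange{input, kernel},
+    //       ValueRange{init}, stridesAttr, dilationsAttr);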
+ ArrayAttr iterator_types(); + ArrayAttr indexing_maps(); + static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); + static std::function + getRegionBuilder() { + return regionBuilder; + } + + // Generic methods. + static unsigned getNumRegionArgs(); + std::string getLibraryCallName(); + + bool hasDynamicIndexingMaps(); + LogicalResult verifyIndexingMapRequiredAttributes(); + + }]; +} + + #endif diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 847a825d2..74b1c1296 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -5,7 +5,11 @@ #include +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Parser.h" +#include "llvm/Support/FormatVariadic.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" @@ -620,6 +624,621 @@ template mlir::LogicalResult verifyMatmul(MatMulOp &op) { } return mlir::success(); } + +//===----------------------------------------------------------------------===// +// Implementation of FhelinalgConv2DNchwFchwOp +// This is a generated functions from `make generate_conv_op`, and some helpers +// from LinalgOps.cpp +//===----------------------------------------------------------------------===// +using namespace mlir; +using namespace mlir::linalg; + +ArrayAttr FhelinalgConv2DNchwFchwOp::iterator_types() { + return Builder(getContext()) + .getStrArrayAttr(SmallVector{ + getParallelIteratorTypeName(), getParallelIteratorTypeName(), + getParallelIteratorTypeName(), getParallelIteratorTypeName(), + getReductionIteratorTypeName(), getReductionIteratorTypeName(), + getReductionIteratorTypeName()}); +} + +static SmallVector +getSymbolBindings(FhelinalgConv2DNchwFchwOp self) { + MLIRContext *context = self.getContext(); + SmallVector exprs; + exprs.push_back(getAffineSymbolExpr(0, context)); + exprs.push_back(getAffineSymbolExpr(1, context)); + exprs.push_back(getAffineSymbolExpr(2, context)); + + int64_t cst3 = self.strides().getValue({0}); + exprs.push_back(getAffineConstantExpr(cst3, context)); + + exprs.push_back(getAffineSymbolExpr(4, context)); + + int64_t cst5 = self.dilations().getValue({0}); + exprs.push_back(getAffineConstantExpr(cst5, context)); + + exprs.push_back(getAffineSymbolExpr(6, context)); + + int64_t cst7 = self.strides().getValue({1}); + exprs.push_back(getAffineConstantExpr(cst7, context)); + + exprs.push_back(getAffineSymbolExpr(8, context)); + + int64_t cst9 = self.dilations().getValue({1}); + exprs.push_back(getAffineConstantExpr(cst9, context)); + + exprs.push_back(getAffineSymbolExpr(10, context)); + return exprs; +} + +ArrayAttr FhelinalgConv2DNchwFchwOp::indexing_maps() { + static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; + ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); + if (cached) + return cached; + + MLIRContext *context = getContext(); + auto symbolBindings = getSymbolBindings(*this); + SmallVector maps; + maps.push_back( + mlir::parseAttribute( + "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, " + "s7, s8, s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>", + context) + .cast() + .getValue()); + maps.back() = simplifyAffineMap( + maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); + maps.push_back(mlir::parseAttribute( + "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, " + 
"s4, s5, s6, s7, s8, s9, s10] -> (d1, d4, d5, d6)>", + context) + .cast() + .getValue()); + maps.back() = simplifyAffineMap( + maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); + maps.push_back(mlir::parseAttribute( + "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, " + "s4, s5, s6, s7, s8, s9, s10] -> (d0, d1, d2, d3)>", + context) + .cast() + .getValue()); + maps.back() = simplifyAffineMap( + maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); + cached = Builder(context).getAffineMapArrayAttr(maps); + getOperation()->setAttr(memoizeAttr, cached); + return cached; +} + +unsigned FhelinalgConv2DNchwFchwOp::getNumRegionArgs() { return 3; } + +std::string FhelinalgConv2DNchwFchwOp::getLibraryCallName() { + return generateLibraryCallName(getOperation()); +} + +bool FhelinalgConv2DNchwFchwOp::hasDynamicIndexingMaps() { return true; } +LogicalResult FhelinalgConv2DNchwFchwOp::verifyIndexingMapRequiredAttributes() { + Operation *op = getOperation(); + + if (auto attr = op->getAttrOfType("strides")) { + if (!attr.getType().getElementType().isInteger(64)) + return op->emitError("incorrect element type for indexing map required " + "attribute 'strides'"); + if (attr.getType().getShape() != ArrayRef{2}) + return op->emitError( + "incorrect shape for indexing map required attribute 'strides'"); + } else { + return op->emitError("missing indexing map required attribute 'strides'"); + } + + if (auto attr = op->getAttrOfType("dilations")) { + if (!attr.getType().getElementType().isInteger(64)) + return op->emitError("incorrect element type for indexing map required " + "attribute 'dilations'"); + if (attr.getType().getShape() != ArrayRef{2}) + return op->emitError( + "incorrect shape for indexing map required attribute 'dilations'"); + } else { + return op->emitError("missing indexing map required attribute 'dilations'"); + } + + return success(); +} + +/// Some helpers were copied from LinalgOps.cpp + +/// Generic entry point to create the block for the region of a LinalgOp. +/// This is used by both named structured ops created by ods-gen and by manually +/// defined C++ ops. +/// This is used by both builders and parsers. +/// This function creates the block in the region with arguments corresponding +/// to the elemental types of `inputTypes` and `outputTypes`. The latter are +/// asserted to be of ShapedType. +template +static void fillStructuredOpRegion( + OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, + TypeRange outputTypes, + std::function errorHandler = nullptr); + +/// Generic entry point to create both the region and the block of a LinalgOp. +template +static void +createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, + TypeRange inputTypes, TypeRange outputTypes); + +/// Common parsing and printing used for both named structured ops created by +/// ods-gen and by manually defined C++ ops. Does not handle regions. +static ParseResult +parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &inputTypes, + SmallVectorImpl &outputTypes); +template +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op); + +/// Specific parsing and printing for named structured ops created by ods-gen. 
+template +static ParseResult +parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, + TypeRange inputTypes, TypeRange outputTypes); + +static ParseResult +parseNamedStructuredOpResults(OpAsmParser &parser, + SmallVectorImpl &resultTypes); + +template +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result); + +static void printNamedStructuredOpResults(OpAsmPrinter &p, + TypeRange resultTypes); + +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); + +class RegionBuilderHelper { +public: + RegionBuilderHelper(MLIRContext *context, Block &block) + : context(context), block(block) {} + + // Generates operations to cast the given operand to a specified type. + // If the cast cannot be performed, a warning will be issued and the + // operand returned as-is (which will presumably yield a verification + // issue downstream). + Value cast(Type toType, Value operand, bool isUnsignedCast) { + OpBuilder builder = getBuilder(); + auto loc = operand.getLoc(); + + if (operand.getType() == toType) + return operand; + if (auto toIntType = toType.dyn_cast()) { + // If operand is floating point, cast directly to the int type. + if (operand.getType().isa()) { + if (isUnsignedCast) + return builder.create(loc, toType, operand); + return builder.create(loc, toType, operand); + } + // Cast index operands directly to the int type. + if (operand.getType().isIndex()) + return builder.create(loc, toType, operand); + if (auto fromIntType = operand.getType().dyn_cast()) { + // Either extend or truncate. + if (toIntType.getWidth() > fromIntType.getWidth()) { + if (isUnsignedCast) + return builder.create(loc, toType, operand); + return builder.create(loc, toType, operand); + } + if (toIntType.getWidth() < fromIntType.getWidth()) + return builder.create(loc, toType, operand); + } + } else if (auto toFloatType = toType.dyn_cast()) { + // If operand is integer, cast directly to the float type. + // Note that it is unclear how to cast from BF16<->FP16. 
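+      // For example, an integer operand feeding a floating point accumulator
+      // is converted with arith.uitofp or arith.sitofp below, depending on
+      // isUnsignedCast.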
+ if (operand.getType().isa()) { + if (isUnsignedCast) + return builder.create(loc, toFloatType, operand); + return builder.create(loc, toFloatType, operand); + } + if (auto fromFloatType = operand.getType().dyn_cast()) { + if (toFloatType.getWidth() > fromFloatType.getWidth()) + return builder.create(loc, toFloatType, operand); + if (toFloatType.getWidth() < fromFloatType.getWidth()) + return builder.create(loc, toFloatType, operand); + } + } + + emitWarning(operand.getLoc()) << "could not cast operand of type " + << operand.getType() << " to " << toType; + return operand; + } + + Value applyfn__add(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs) && + rhs.getType().isa()) { + return builder.create(lhs.getLoc(), + rhs, lhs); + } + if (lhs.getType().isa() && + isInteger(rhs)) { + return builder.create(lhs.getLoc(), + lhs, rhs); + } + if (lhs.getType().isa() && + rhs.getType().isa()) { + return builder.create(lhs.getLoc(), + lhs, rhs); + } + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__exp(Value x) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(x)) + return builder.create(x.getLoc(), x); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__log(Value x) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(x)) + return builder.create(x.getLoc(), x); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__sub(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__mul(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (lhs.getType().isa() && + isInteger(rhs)) { + return builder.create(lhs.getLoc(), + lhs, rhs); + } + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__max(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__max_unsigned(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__min(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__min_unsigned(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + void yieldOutputs(ValueRange values) { + assert(!values.empty() && "linalg ops must yield outputs"); + if (values.empty()) + return; + Value first = values.front(); + OpBuilder builder = getBuilder(); + 
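+    // Terminate the structured op body by yielding the computed values
+    // (linalg.yield in the upstream helper this class was copied from).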
builder.create(first.getLoc(), values); + } + + Value constant(std::string value) { + OpBuilder builder = getBuilder(); + Location loc = builder.getUnknownLoc(); + Attribute valueAttr = parseAttribute(value, builder.getContext()); + return builder.create(loc, valueAttr.getType(), + valueAttr); + } + + Value index(int64_t dim) { + OpBuilder builder = getBuilder(); + return builder.create(builder.getUnknownLoc(), dim); + } + + Type getIntegerType(unsigned width) { + return IntegerType::get(context, width); + } + + Type getFloat32Type() { return Float32Type::get(context); } + + Type getFloat64Type() { return Float64Type::get(context); } + +private: + MLIRContext *context; + Block █ + + bool isFloatingPoint(Value value) { return value.getType().isa(); } + bool isInteger(Value value) { return value.getType().isa(); } + + OpBuilder getBuilder() { + OpBuilder builder(context); + builder.setInsertionPointToEnd(&block); + return builder; + } +}; + +static LogicalResult foldMemRefCast(Operation *op) { + bool folded = false; + for (OpOperand &operand : op->getOpOperands()) { + auto castOp = operand.get().getDefiningOp(); + if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { + operand.set(castOp.getOperand()); + folded = true; + } + } + return success(folded); +} + +/// Generic entry point to create the block for the region of a LinalgOp. +/// This is used by both named structured ops created by ods-gen and by manually +/// defined C++ ops. +/// This is used by both builders and parsers. +/// This function creates the block in the region with arguments corresponding +/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted +/// to be ShapedType. +template +static void +fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, + TypeRange inputTypes, TypeRange outputTypes, + std::function errorHandler) { + assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); + + // TODO: atm all operands go through getElementTypeOrSelf, + // reconsider when we have evidence we need to. + SmallVector argTypes; + for (auto containers : {inputTypes, outputTypes}) + for (auto t : containers) + argTypes.push_back(getElementTypeOrSelf(t)); + + // RAII. + OpBuilder::InsertionGuard guard(opBuilder); + Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes); + unsigned actual = body->getNumArguments(); + unsigned expected = NamedStructuredOpType::getNumRegionArgs(); + if (expected != actual) { + if (errorHandler) + errorHandler(expected, actual); + return; + } + + opBuilder.setInsertionPointToStart(body); + ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); + NamedStructuredOpType::regionBuilder(b, *body); + + // indexing_maps is an auto-generated method. + + // iterator_types is an auto-generated method. +} + +static void getGenericEffectsImpl( + SmallVectorImpl> + &effects, + ValueRange results, ValueRange inputBuffers, ValueRange outputs) { + for (Value value : results) { + effects.emplace_back(MemoryEffects::Allocate::get(), value, + SideEffects::DefaultResource::get()); + } + for (Value value : inputBuffers) { + effects.emplace_back(MemoryEffects::Read::get(), value, + SideEffects::DefaultResource::get()); + } + for (Value value : outputs) { + effects.emplace_back(MemoryEffects::Read::get(), value, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), value, + SideEffects::DefaultResource::get()); + } +} + +/// Generic entry point to create both the region and the block of a LinalgOp. 
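+/// This is what the ODS builders declared in FHELinalgOps.td call to
+/// materialize the implicit region when the op is created programmatically
+/// rather than parsed.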
+template +void createAndFillStructuredOpRegion(OpBuilder &opBuilder, + OperationState &result, + TypeRange inputTypes, + TypeRange outputTypes) { + Region ®ion = *result.addRegion(); + fillStructuredOpRegion( + opBuilder, region, inputTypes, outputTypes, + [&](unsigned expected, unsigned actual) { + assert(expected != actual && "incorrect number of arguments"); + }); +} + +static void printNamedStructuredOpResults(OpAsmPrinter &p, + TypeRange resultTypes) { + if (resultTypes.empty()) + return; + p.printOptionalArrowTypeList(resultTypes); +} + +template +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op) { + if (!op.inputs().empty()) + p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; + if (!op.outputs().empty()) + p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")"; +} + +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { + p.printOptionalAttrDict( + op->getAttrs(), + /*elidedAttrs=*/{"operand_segment_sizes", + // See generated code in mlir-linalg-yaml-gen.cpp + "linalg.memoized_indexing_maps"}); + + // Printing is shared with generic ops, except for the region and + // attributes. + printCommonStructuredOpParts(p, op); + + // Results printing. + printNamedStructuredOpResults(p, op.result_tensors().getTypes()); + + // Region is elided. +} + +/// Common parsing used for both named structured ops created by ods-gen and by +/// manually defined C++ ops. Does not handle regions. +static ParseResult +parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &inputTypes, + SmallVectorImpl &outputTypes) { + llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc; + SmallVector inputsOperands, outputsOperands; + + parser.parseOptionalAttrDict(result.attributes); + + if (succeeded(parser.parseOptionalKeyword("ins"))) { + if (parser.parseLParen()) + return failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands) || + parser.parseColonTypeList(inputTypes) || parser.parseRParen()) + return failure(); + } + + if (succeeded(parser.parseOptionalKeyword("outs"))) { + outputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || + parser.parseColonTypeList(outputTypes) || parser.parseRParen()) + return failure(); + } + + if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, + result.operands) || + parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, + result.operands)) + return failure(); + + result.addAttribute("operand_segment_sizes", + parser.getBuilder().getI32VectorAttr( + {static_cast(inputsOperands.size()), + static_cast(outputsOperands.size())})); + return success(); +} + +template +static ParseResult +parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, + TypeRange inputTypes, TypeRange outputTypes) { + ParseResult res = success(); + OpBuilder opBuilder(parser.getContext()); + // Resolve `captures` into `capturedValues` at parse time so we can build the + // region with captures. 
+ SmallVector capturedValues; + fillStructuredOpRegion( + opBuilder, region, inputTypes, outputTypes, + [&](unsigned expected, unsigned actual) { + res = parser.emitError( + parser.getCurrentLocation(), + llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " + "region expects {0} args, got {1}", + expected, actual)); + region.front().dump(); + }); + return res; +} + +static ParseResult +parseNamedStructuredOpResults(OpAsmParser &parser, + SmallVectorImpl &resultTypes) { + if (parser.parseOptionalArrowTypeList(resultTypes)) + return failure(); + return success(); +} + +template +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result) { + // TODO: Enable when ods-gen supports captures. + SmallVector inputTypes, outputTypes; + if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) + return failure(); + + // TODO: consider merging results parsing into region parsing. + // Need to wait for declarative assembly resolution to decide. + SmallVector outputTensorsTypes; + if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) + return failure(); + result.addTypes(outputTensorsTypes); + + std::unique_ptr region = std::make_unique(); + if (parseNamedStructuredOpRegion( + parser, *region, inputTypes, outputTypes)) + return failure(); + result.addRegion(std::move(region)); + + return success(); +} + +/// END OF COPY FROM LinalgOps.cpp + +void FhelinalgConv2DNchwFchwOp::regionBuilder(ImplicitLocOpBuilder &b, + Block &block) { + assert(3 > 0 && block.getNumArguments() == 3 && + "FhelinalgConv2DNchwFchwOp regionBuilder expects 3 (>=0) args"); + RegionBuilderHelper helper(block.getArgument(0).getContext(), block); + SmallVector yields; + Value value1 = + helper.cast(block.getArgument(0).getType(), block.getArgument(0), false); + Value value2 = + helper.cast(block.getArgument(1).getType(), block.getArgument(1), false); + Value value3 = helper.applyfn__mul(value1, value2); + Value value4 = helper.applyfn__add(block.getArgument(2), value3); + yields.push_back(value4); + helper.yieldOutputs(yields); +} + +LogicalResult FhelinalgConv2DNchwFchwOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} +void FhelinalgConv2DNchwFchwOp::getEffects( + SmallVectorImpl> + &effects) { + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); + getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, + outputBuffers); +} + } // namespace FHELinalg } // namespace concretelang } // namespace mlir diff --git a/compiler/ops/core_named_ops.py b/compiler/ops/core_named_ops.py new file mode 100644 index 000000000..0f45f33e7 --- /dev/null +++ b/compiler/ops/core_named_ops.py @@ -0,0 +1,29 @@ +from mlir.dialects.linalg.opdsl.lang import * + +T1 = TV.T1 +T2 = TV.T2 + +Batch = S.Batch + + +@linalg_structured_op +def fhelinalg_conv_2d_nchw_fchw( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs 2-D convolution. + + Layout: + * Input: NCHW. + * Kernel: FCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
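+
+    Informally, each output element is updated as
+      O[n, f, oh, ow] += cast(U, I[n, c, oh * SH + kh * DH, ow * SW + kw * DW])
+                         * cast(U, K[f, c, kh, kw])
+    which is exactly what the op body below expresses.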
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.f, D.oh, D.ow] += cast(
+        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW
+             ]) * cast(U, K[D.f, D.c, D.kh, D.kw])
\ No newline at end of file