diff --git a/compiler/Makefile b/compiler/Makefile index ed9e99028..5be572c83 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -203,6 +203,11 @@ release_tarballs: update_python_version: echo "__version__ = \"`git describe --tags --abbrev=0 | grep -e '[0-9].*' -o`\"" > lib/Bindings/Python/version.txt +generate_conv_op: + python -m mlir.dialects.linalg.opdsl.dump_oplib ops.core_named_ops > ops/LinalgNamedStructuredOps.yaml + $(BUILD_DIR)/bin/mlir-linalg-ods-yaml-gen ops/LinalgNamedStructuredOps.yaml --o-impl=ops/LinalgOps.cpp --o-ods-decl=ops/LinalgNamedStructuredOps.yamlgen.td + + .PHONY: build-initialized \ build-end-to-end-jit \ concretecompiler \ @@ -226,4 +231,5 @@ update_python_version: install \ uninstall\ install_runtime_lib \ - uninstall_runtime_lib + uninstall_runtime_lib \ + generate_conv_op diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h index 0ffa52b7a..3f6bbfd46 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h @@ -11,6 +11,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include #include #include @@ -96,6 +97,10 @@ public: } // namespace mlir #define GET_OP_CLASSES +// TODO: remove this when removing the custom linalg op for Conv +// the generated code was calling functions from the mlir::linalg namespace +using namespace mlir::linalg; +// END TODO #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc" #endif diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index e82d86cbe..670c0a87c 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -3,6 +3,9 @@ include 
"mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td" include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.td" @@ -628,4 +631,153 @@ def ConcatOp : FHELinalg_Op<"concat"> { }]; } +class LinalgStructuredBase_Op props> + : Op, + DeclareOpInterfaceMethods, + LinalgStructuredInterface, + ReifyRankedShapedTypeOpInterface], props)> { + code structuredOpsBaseDecls = [{ + // Return whether the op accesses the iteration indices. + bool hasIndexSemantics() { + return !this->getBody()->getOps().empty(); + } + + LogicalResult reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + return cast(getOperation()).reifyResultShapes(b, + reifiedReturnShapes); + } + }]; +} + +def FhelinalgConv2DNchwFchwOp : LinalgStructuredBase_Op<"fhelinalg_conv_2d_nchw_fchw", !listconcat([AttrSizedOperandSegments], + /*extraInterfaces=*/[LinalgConvolutionOpInterface])> { + + let cppNamespace = "mlir::concretelang::FHELinalg"; + let summary = [{ Performs 2-D convolution. }]; + let description = [{ + Layout: + * Input: NCHW. + * Kernel: FCHW. + +Numeric casting is performed on the operands to the inner multiply, promoting +them to the same data type as the accumulator/output. 
+ }]; + + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs, +RankedI64ElementsAttr<[2]>:$strides, +RankedI64ElementsAttr<[2]>:$dilations + ); + let results = (outs Variadic:$result_tensors); + let regions = (region AnyRegion:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder< + (ins "ValueRange":$inputs, "ValueRange":$outputs, + CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addOperands(inputs); + $_state.addOperands(outputs); + SmallVector resultTensorTypes; + copy_if(outputs.getTypes(), + std::back_inserter(resultTensorTypes), + [](Type type) { return type.isa(); }); + $_state.addTypes(resultTensorTypes); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({ + static_cast(inputs.size()), + static_cast(outputs.size())})); + $_state.addAttributes(attributes); + createAndFillStructuredOpRegion( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)); + }]>, + OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs, + CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addOperands(inputs); + $_state.addOperands(outputs); + $_state.addTypes(resultTensorTypes); + $_state.addAttributes(attributes); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({ + static_cast(inputs.size()), + static_cast(outputs.size())})); + createAndFillStructuredOpRegion( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)); + }]>, + OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands, + CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addOperands(operands); + $_state.addAttributes(attributes); + $_state.addTypes(resultTensorTypes); + (void)$_state.addRegion(); + }]> + + , OpBuilder< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs, "Attribute":$strides, "Attribute":$dilations, + CArg<"ArrayRef", "{}">:$attributes), + [{ + $_state.addOperands(inputs); + 
$_state.addOperands(outputs); + $_state.addTypes(resultTensorTypes); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({ + static_cast(inputs.size()), + static_cast(outputs.size())})); + createAndFillStructuredOpRegion( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)); + $_state.addAttribute("strides", strides); +$_state.addAttribute("dilations", dilations); + $_state.addAttributes(attributes); + }]> + + ]; + let printer = [{ return mlir::concretelang::FHELinalg::printNamedStructuredOp(p, *this); }]; + let parser = [{ + return mlir::concretelang::FHELinalg::parseNamedStructuredOp(parser, result); + }]; + let hasFolder = 1; + + let extraClassDeclaration = structuredOpsBaseDecls # [{ + // Auto-generated. + ArrayAttr iterator_types(); + ArrayAttr indexing_maps(); + static void regionBuilder(ImplicitLocOpBuilder &b, Block &block); + static std::function + getRegionBuilder() { + return regionBuilder; + } + + // Generic methods. + static unsigned getNumRegionArgs(); + std::string getLibraryCallName(); + + bool hasDynamicIndexingMaps(); + LogicalResult verifyIndexingMapRequiredAttributes(); + + }]; +} + + #endif diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 847a825d2..74b1c1296 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -5,7 +5,11 @@ #include +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Parser.h" +#include "llvm/Support/FormatVariadic.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" @@ -620,6 +624,621 @@ template mlir::LogicalResult verifyMatmul(MatMulOp &op) { } return mlir::success(); } + +//===----------------------------------------------------------------------===// +// Implementation of FhelinalgConv2DNchwFchwOp 
+// This is a generated functions from `make generate_conv_op`, and some helpers +// from LinalgOps.cpp +//===----------------------------------------------------------------------===// +using namespace mlir; +using namespace mlir::linalg; + +ArrayAttr FhelinalgConv2DNchwFchwOp::iterator_types() { + return Builder(getContext()) + .getStrArrayAttr(SmallVector{ + getParallelIteratorTypeName(), getParallelIteratorTypeName(), + getParallelIteratorTypeName(), getParallelIteratorTypeName(), + getReductionIteratorTypeName(), getReductionIteratorTypeName(), + getReductionIteratorTypeName()}); +} + +static SmallVector +getSymbolBindings(FhelinalgConv2DNchwFchwOp self) { + MLIRContext *context = self.getContext(); + SmallVector exprs; + exprs.push_back(getAffineSymbolExpr(0, context)); + exprs.push_back(getAffineSymbolExpr(1, context)); + exprs.push_back(getAffineSymbolExpr(2, context)); + + int64_t cst3 = self.strides().getValue({0}); + exprs.push_back(getAffineConstantExpr(cst3, context)); + + exprs.push_back(getAffineSymbolExpr(4, context)); + + int64_t cst5 = self.dilations().getValue({0}); + exprs.push_back(getAffineConstantExpr(cst5, context)); + + exprs.push_back(getAffineSymbolExpr(6, context)); + + int64_t cst7 = self.strides().getValue({1}); + exprs.push_back(getAffineConstantExpr(cst7, context)); + + exprs.push_back(getAffineSymbolExpr(8, context)); + + int64_t cst9 = self.dilations().getValue({1}); + exprs.push_back(getAffineConstantExpr(cst9, context)); + + exprs.push_back(getAffineSymbolExpr(10, context)); + return exprs; +} + +ArrayAttr FhelinalgConv2DNchwFchwOp::indexing_maps() { + static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; + ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); + if (cached) + return cached; + + MLIRContext *context = getContext(); + auto symbolBindings = getSymbolBindings(*this); + SmallVector maps; + maps.push_back( + mlir::parseAttribute( + "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, 
s6, " + "s7, s8, s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>", + context) + .cast() + .getValue()); + maps.back() = simplifyAffineMap( + maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); + maps.push_back(mlir::parseAttribute( + "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, " + "s4, s5, s6, s7, s8, s9, s10] -> (d1, d4, d5, d6)>", + context) + .cast() + .getValue()); + maps.back() = simplifyAffineMap( + maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); + maps.push_back(mlir::parseAttribute( + "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, " + "s4, s5, s6, s7, s8, s9, s10] -> (d0, d1, d2, d3)>", + context) + .cast() + .getValue()); + maps.back() = simplifyAffineMap( + maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); + cached = Builder(context).getAffineMapArrayAttr(maps); + getOperation()->setAttr(memoizeAttr, cached); + return cached; +} + +unsigned FhelinalgConv2DNchwFchwOp::getNumRegionArgs() { return 3; } + +std::string FhelinalgConv2DNchwFchwOp::getLibraryCallName() { + return generateLibraryCallName(getOperation()); +} + +bool FhelinalgConv2DNchwFchwOp::hasDynamicIndexingMaps() { return true; } +LogicalResult FhelinalgConv2DNchwFchwOp::verifyIndexingMapRequiredAttributes() { + Operation *op = getOperation(); + + if (auto attr = op->getAttrOfType("strides")) { + if (!attr.getType().getElementType().isInteger(64)) + return op->emitError("incorrect element type for indexing map required " + "attribute 'strides'"); + if (attr.getType().getShape() != ArrayRef{2}) + return op->emitError( + "incorrect shape for indexing map required attribute 'strides'"); + } else { + return op->emitError("missing indexing map required attribute 'strides'"); + } + + if (auto attr = op->getAttrOfType("dilations")) { + if (!attr.getType().getElementType().isInteger(64)) + return op->emitError("incorrect element type for indexing map required " + "attribute 'dilations'"); + if (attr.getType().getShape() != 
ArrayRef{2}) + return op->emitError( + "incorrect shape for indexing map required attribute 'dilations'"); + } else { + return op->emitError("missing indexing map required attribute 'dilations'"); + } + + return success(); +} + +/// Some helpers were copied from LinalgOps.cpp + +/// Generic entry point to create the block for the region of a LinalgOp. +/// This is used by both named structured ops created by ods-gen and by manually +/// defined C++ ops. +/// This is used by both builders and parsers. +/// This function creates the block in the region with arguments corresponding +/// to the elemental types of `inputTypes` and `outputTypes`. The latter are +/// asserted to be of ShapedType. +template +static void fillStructuredOpRegion( + OpBuilder &opBuilder, Region &region, TypeRange inputTypes, + TypeRange outputTypes, + std::function errorHandler = nullptr); + +/// Generic entry point to create both the region and the block of a LinalgOp. +template +static void +createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, + TypeRange inputTypes, TypeRange outputTypes); + +/// Common parsing and printing used for both named structured ops created by +/// ods-gen and by manually defined C++ ops. Does not handle regions. +static ParseResult +parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &inputTypes, + SmallVectorImpl &outputTypes); +template +static void printCommonStructuredOpParts(OpAsmPrinter &p, + NamedStructuredOpType op); + +/// Specific parsing and printing for named structured ops created by ods-gen. 
+template +static ParseResult +parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, + TypeRange inputTypes, TypeRange outputTypes); + +static ParseResult +parseNamedStructuredOpResults(OpAsmParser &parser, + SmallVectorImpl &resultTypes); + +template +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result); + +static void printNamedStructuredOpResults(OpAsmPrinter &p, + TypeRange resultTypes); + +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); + +class RegionBuilderHelper { +public: + RegionBuilderHelper(MLIRContext *context, Block &block) + : context(context), block(block) {} + + // Generates operations to cast the given operand to a specified type. + // If the cast cannot be performed, a warning will be issued and the + // operand returned as-is (which will presumably yield a verification + // issue downstream). + Value cast(Type toType, Value operand, bool isUnsignedCast) { + OpBuilder builder = getBuilder(); + auto loc = operand.getLoc(); + + if (operand.getType() == toType) + return operand; + if (auto toIntType = toType.dyn_cast()) { + // If operand is floating point, cast directly to the int type. + if (operand.getType().isa()) { + if (isUnsignedCast) + return builder.create(loc, toType, operand); + return builder.create(loc, toType, operand); + } + // Cast index operands directly to the int type. + if (operand.getType().isIndex()) + return builder.create(loc, toType, operand); + if (auto fromIntType = operand.getType().dyn_cast()) { + // Either extend or truncate. + if (toIntType.getWidth() > fromIntType.getWidth()) { + if (isUnsignedCast) + return builder.create(loc, toType, operand); + return builder.create(loc, toType, operand); + } + if (toIntType.getWidth() < fromIntType.getWidth()) + return builder.create(loc, toType, operand); + } + } else if (auto toFloatType = toType.dyn_cast()) { + // If operand is integer, cast directly to the float type. 
+ // Note that it is unclear how to cast from BF16<->FP16. + if (operand.getType().isa()) { + if (isUnsignedCast) + return builder.create(loc, toFloatType, operand); + return builder.create(loc, toFloatType, operand); + } + if (auto fromFloatType = operand.getType().dyn_cast()) { + if (toFloatType.getWidth() > fromFloatType.getWidth()) + return builder.create(loc, toFloatType, operand); + if (toFloatType.getWidth() < fromFloatType.getWidth()) + return builder.create(loc, toFloatType, operand); + } + } + + emitWarning(operand.getLoc()) << "could not cast operand of type " + << operand.getType() << " to " << toType; + return operand; + } + + Value applyfn__add(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs) && + rhs.getType().isa()) { + return builder.create(lhs.getLoc(), + rhs, lhs); + } + if (lhs.getType().isa() && + isInteger(rhs)) { + return builder.create(lhs.getLoc(), + lhs, rhs); + } + if (lhs.getType().isa() && + rhs.getType().isa()) { + return builder.create(lhs.getLoc(), + lhs, rhs); + } + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__exp(Value x) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(x)) + return builder.create(x.getLoc(), x); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__log(Value x) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(x)) + return builder.create(x.getLoc(), x); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__sub(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__mul(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), 
lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (lhs.getType().isa() && + isInteger(rhs)) { + return builder.create(lhs.getLoc(), + lhs, rhs); + } + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__max(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__max_unsigned(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__min(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__min_unsigned(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(); + if (isFloatingPoint(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + if (isInteger(lhs)) + return builder.create(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + void yieldOutputs(ValueRange values) { + assert(!values.empty() && "linalg ops must yield outputs"); + if (values.empty()) + return; + Value first = values.front(); + OpBuilder builder = getBuilder(); + builder.create(first.getLoc(), values); + } + + Value constant(std::string value) { + OpBuilder builder = getBuilder(); + Location loc = builder.getUnknownLoc(); + Attribute valueAttr = parseAttribute(value, builder.getContext()); + return builder.create(loc, valueAttr.getType(), + valueAttr); + } + + Value index(int64_t dim) { + OpBuilder builder = getBuilder(); + return 
builder.create(builder.getUnknownLoc(), dim); + } + + Type getIntegerType(unsigned width) { + return IntegerType::get(context, width); + } + + Type getFloat32Type() { return Float32Type::get(context); } + + Type getFloat64Type() { return Float64Type::get(context); } + +private: + MLIRContext *context; + Block &block; + + bool isFloatingPoint(Value value) { return value.getType().isa(); } + bool isInteger(Value value) { return value.getType().isa(); } + + OpBuilder getBuilder() { + OpBuilder builder(context); + builder.setInsertionPointToEnd(&block); + return builder; + } +}; + +static LogicalResult foldMemRefCast(Operation *op) { + bool folded = false; + for (OpOperand &operand : op->getOpOperands()) { + auto castOp = operand.get().getDefiningOp(); + if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { + operand.set(castOp.getOperand()); + folded = true; + } + } + return success(folded); +} + +/// Generic entry point to create the block for the region of a LinalgOp. +/// This is used by both named structured ops created by ods-gen and by manually +/// defined C++ ops. +/// This is used by both builders and parsers. +/// This function creates the block in the region with arguments corresponding +/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted +/// to be ShapedType. +template +static void +fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, + TypeRange inputTypes, TypeRange outputTypes, + std::function errorHandler) { + assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); })); + + // TODO: atm all operands go through getElementTypeOrSelf, + // reconsider when we have evidence we need to. + SmallVector argTypes; + for (auto containers : {inputTypes, outputTypes}) + for (auto t : containers) + argTypes.push_back(getElementTypeOrSelf(t)); + + // RAII. 
+ OpBuilder::InsertionGuard guard(opBuilder); + Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes); + unsigned actual = body->getNumArguments(); + unsigned expected = NamedStructuredOpType::getNumRegionArgs(); + if (expected != actual) { + if (errorHandler) + errorHandler(expected, actual); + return; + } + + opBuilder.setInsertionPointToStart(body); + ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder); + NamedStructuredOpType::regionBuilder(b, *body); + + // indexing_maps is an auto-generated method. + + // iterator_types is an auto-generated method. +} + +static void getGenericEffectsImpl( + SmallVectorImpl> + &effects, + ValueRange results, ValueRange inputBuffers, ValueRange outputs) { + for (Value value : results) { + effects.emplace_back(MemoryEffects::Allocate::get(), value, + SideEffects::DefaultResource::get()); + } + for (Value value : inputBuffers) { + effects.emplace_back(MemoryEffects::Read::get(), value, + SideEffects::DefaultResource::get()); + } + for (Value value : outputs) { + effects.emplace_back(MemoryEffects::Read::get(), value, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), value, + SideEffects::DefaultResource::get()); + } +} + +/// Generic entry point to create both the region and the block of a LinalgOp. 
+/// Adds a new region to `result` and fills it via fillStructuredOpRegion,
+/// which creates one block whose arguments are the element types of
+/// `inputTypes` followed by those of `outputTypes`.
+/// fillStructuredOpRegion calls the error handler only when the number of
+/// block arguments differs from NamedStructuredOpType::getNumRegionArgs(), so
+/// reaching the handler means the builder was given the wrong operand count.
+template <typename NamedStructuredOpType>
+void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
+                                     OperationState &result,
+                                     TypeRange inputTypes,
+                                     TypeRange outputTypes) {
+  Region &region = *result.addRegion();
+  fillStructuredOpRegion<NamedStructuredOpType>(
+      opBuilder, region, inputTypes, outputTypes,
+      [&](unsigned expected, unsigned actual) {
+        // This handler only runs on a mismatch, so the former condition
+        // `expected != actual` was always true and the assert could never
+        // fire. Assert equality so the mismatch is reported in
+        // assert-enabled builds instead of leaving an unpopulated block.
+        assert(expected == actual && "incorrect number of arguments");
+        // Avoid unused-parameter warnings when asserts are compiled out.
+        (void)expected;
+        (void)actual;
+      });
+}
+
+/// Prints the optional arrow type list for `resultTypes`; prints nothing when
+/// the op has no results.
+static void printNamedStructuredOpResults(OpAsmPrinter &p,
+                                          TypeRange resultTypes) {
+  if (resultTypes.empty())
+    return;
+  p.printOptionalArrowTypeList(resultTypes);
+}
+
+/// Prints the ` ins(<operands> : <types>)` and ` outs(<operands> : <types>)`
+/// clauses of `op`, skipping a clause whose operand list is empty.
+template <typename NamedStructuredOpType>
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+                                         NamedStructuredOpType op) {
+  if (!op.inputs().empty())
+    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
+  if (!op.outputs().empty())
+    p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
+}
+
+/// Prints a named structured op: the attribute dictionary (without the
+/// operand segment sizes and the memoized indexing maps), then the ins/outs
+/// clauses, then the result types. The region is not printed.
+template <typename NamedStructuredOpType>
+static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
+  p.printOptionalAttrDict(
+      op->getAttrs(),
+      /*elidedAttrs=*/{"operand_segment_sizes",
+                       // See generated code in mlir-linalg-yaml-gen.cpp
+                       "linalg.memoized_indexing_maps"});
+
+  // Printing is shared with generic ops, except for the region and
+  // attributes.
+  printCommonStructuredOpParts(p, op);
+
+  // Results printing.
+  printNamedStructuredOpResults(p, op.result_tensors().getTypes());
+
+  // Region is elided.
+}
+
+/// Common parsing used for both named structured ops created by ods-gen and by
+/// manually defined C++ ops. Does not handle regions.
+static ParseResult +parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, + SmallVectorImpl &inputTypes, + SmallVectorImpl &outputTypes) { + llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc; + SmallVector inputsOperands, outputsOperands; + + parser.parseOptionalAttrDict(result.attributes); + + if (succeeded(parser.parseOptionalKeyword("ins"))) { + if (parser.parseLParen()) + return failure(); + + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands) || + parser.parseColonTypeList(inputTypes) || parser.parseRParen()) + return failure(); + } + + if (succeeded(parser.parseOptionalKeyword("outs"))) { + outputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseLParen() || parser.parseOperandList(outputsOperands) || + parser.parseColonTypeList(outputTypes) || parser.parseRParen()) + return failure(); + } + + if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, + result.operands) || + parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc, + result.operands)) + return failure(); + + result.addAttribute("operand_segment_sizes", + parser.getBuilder().getI32VectorAttr( + {static_cast(inputsOperands.size()), + static_cast(outputsOperands.size())})); + return success(); +} + +template +static ParseResult +parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, + TypeRange inputTypes, TypeRange outputTypes) { + ParseResult res = success(); + OpBuilder opBuilder(parser.getContext()); + // Resolve `captures` into `capturedValues` at parse time so we can build the + // region with captures. 
+ SmallVector capturedValues; + fillStructuredOpRegion( + opBuilder, region, inputTypes, outputTypes, + [&](unsigned expected, unsigned actual) { + res = parser.emitError( + parser.getCurrentLocation(), + llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated " + "region expects {0} args, got {1}", + expected, actual)); + region.front().dump(); + }); + return res; +} + +static ParseResult +parseNamedStructuredOpResults(OpAsmParser &parser, + SmallVectorImpl &resultTypes) { + if (parser.parseOptionalArrowTypeList(resultTypes)) + return failure(); + return success(); +} + +template +static ParseResult parseNamedStructuredOp(OpAsmParser &parser, + OperationState &result) { + // TODO: Enable when ods-gen supports captures. + SmallVector inputTypes, outputTypes; + if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes)) + return failure(); + + // TODO: consider merging results parsing into region parsing. + // Need to wait for declarative assembly resolution to decide. 
+ SmallVector outputTensorsTypes; + if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) + return failure(); + result.addTypes(outputTensorsTypes); + + std::unique_ptr region = std::make_unique(); + if (parseNamedStructuredOpRegion( + parser, *region, inputTypes, outputTypes)) + return failure(); + result.addRegion(std::move(region)); + + return success(); +} + +/// END OF COPY FROM LinalgOps.cpp + +void FhelinalgConv2DNchwFchwOp::regionBuilder(ImplicitLocOpBuilder &b, + Block &block) { + assert(3 > 0 && block.getNumArguments() == 3 && + "FhelinalgConv2DNchwFchwOp regionBuilder expects 3 (>=0) args"); + RegionBuilderHelper helper(block.getArgument(0).getContext(), block); + SmallVector yields; + Value value1 = + helper.cast(block.getArgument(0).getType(), block.getArgument(0), false); + Value value2 = + helper.cast(block.getArgument(1).getType(), block.getArgument(1), false); + Value value3 = helper.applyfn__mul(value1, value2); + Value value4 = helper.applyfn__add(block.getArgument(2), value3); + yields.push_back(value4); + helper.yieldOutputs(yields); +} + +LogicalResult FhelinalgConv2DNchwFchwOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} +void FhelinalgConv2DNchwFchwOp::getEffects( + SmallVectorImpl> + &effects) { + SmallVector inputBuffers = getInputBufferOperands(); + SmallVector outputBuffers = getOutputBufferOperands(); + getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, + outputBuffers); +} + } // namespace FHELinalg } // namespace concretelang } // namespace mlir diff --git a/compiler/ops/core_named_ops.py b/compiler/ops/core_named_ops.py new file mode 100644 index 000000000..0f45f33e7 --- /dev/null +++ b/compiler/ops/core_named_ops.py @@ -0,0 +1,29 @@ +from mlir.dialects.linalg.opdsl.lang import * + +T1 = TV.T1 +T2 = TV.T2 + +Batch = S.Batch + + +@linalg_structured_op +def fhelinalg_conv_2d_nchw_fchw( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + 
K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs 2-D convolution. + + Layout: + * Input: NCHW. + * Kernel: FCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += cast( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW + ]) * cast(U, K[D.f, D.c, D.kh, D.kw]) \ No newline at end of file