|
|
|
|
@@ -6,13 +6,9 @@
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
|
|
|
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
|
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
|
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
|
|
|
#include "mlir/IR/OpImplementation.h"
|
|
|
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
|
|
|
#include "mlir/Parser/Parser.h"
|
|
|
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
|
|
|
|
|
|
|
|
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
|
|
|
|
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
|
|
|
|
|
@@ -1019,688 +1015,6 @@ mlir::LogicalResult FromElementOp::verify() {
|
|
|
|
|
return mlir::success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// Implementation of FhelinalgConv2DNchwFchwOp
|
|
|
|
|
// This is a generated functions from `make generate_conv_op`, and some helpers
|
|
|
|
|
// from LinalgOps.cpp
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
using namespace mlir::linalg;
|
|
|
|
|
|
|
|
|
|
ArrayAttr FhelinalgConv2DNchwFchwOp::iterator_types() {
|
|
|
|
|
return Builder(getContext())
|
|
|
|
|
.getStrArrayAttr(SmallVector<StringRef>{
|
|
|
|
|
getParallelIteratorTypeName(), getParallelIteratorTypeName(),
|
|
|
|
|
getParallelIteratorTypeName(), getParallelIteratorTypeName(),
|
|
|
|
|
getReductionIteratorTypeName(), getReductionIteratorTypeName(),
|
|
|
|
|
getReductionIteratorTypeName()});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static SmallVector<AffineExpr>
|
|
|
|
|
getSymbolBindings(FhelinalgConv2DNchwFchwOp self) {
|
|
|
|
|
MLIRContext *context = self.getContext();
|
|
|
|
|
SmallVector<AffineExpr> exprs;
|
|
|
|
|
exprs.push_back(getAffineSymbolExpr(0, context));
|
|
|
|
|
exprs.push_back(getAffineSymbolExpr(1, context));
|
|
|
|
|
exprs.push_back(getAffineSymbolExpr(2, context));
|
|
|
|
|
|
|
|
|
|
int64_t cst3 = self.strides().getValues<int64_t>()[{0}];
|
|
|
|
|
exprs.push_back(getAffineConstantExpr(cst3, context));
|
|
|
|
|
|
|
|
|
|
exprs.push_back(getAffineSymbolExpr(4, context));
|
|
|
|
|
|
|
|
|
|
int64_t cst5 = self.dilations().getValues<int64_t>()[{0}];
|
|
|
|
|
exprs.push_back(getAffineConstantExpr(cst5, context));
|
|
|
|
|
|
|
|
|
|
exprs.push_back(getAffineSymbolExpr(6, context));
|
|
|
|
|
|
|
|
|
|
int64_t cst7 = self.strides().getValues<int64_t>()[{1}];
|
|
|
|
|
exprs.push_back(getAffineConstantExpr(cst7, context));
|
|
|
|
|
|
|
|
|
|
exprs.push_back(getAffineSymbolExpr(8, context));
|
|
|
|
|
|
|
|
|
|
int64_t cst9 = self.dilations().getValues<int64_t>()[{1}];
|
|
|
|
|
exprs.push_back(getAffineConstantExpr(cst9, context));
|
|
|
|
|
|
|
|
|
|
exprs.push_back(getAffineSymbolExpr(10, context));
|
|
|
|
|
return exprs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ArrayAttr FhelinalgConv2DNchwFchwOp::indexing_maps() {
|
|
|
|
|
static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
|
|
|
|
|
ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
|
|
|
|
|
if (cached)
|
|
|
|
|
return cached;
|
|
|
|
|
|
|
|
|
|
MLIRContext *context = getContext();
|
|
|
|
|
auto symbolBindings = getSymbolBindings(*this);
|
|
|
|
|
SmallVector<AffineMap> maps;
|
|
|
|
|
maps.push_back(
|
|
|
|
|
mlir::parseAttribute(
|
|
|
|
|
"affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, "
|
|
|
|
|
"s7, s8, s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>",
|
|
|
|
|
context)
|
|
|
|
|
.cast<AffineMapAttr>()
|
|
|
|
|
.getValue());
|
|
|
|
|
maps.back() = simplifyAffineMap(
|
|
|
|
|
maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0));
|
|
|
|
|
maps.push_back(mlir::parseAttribute(
|
|
|
|
|
"affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, "
|
|
|
|
|
"s4, s5, s6, s7, s8, s9, s10] -> (d1, d4, d5, d6)>",
|
|
|
|
|
context)
|
|
|
|
|
.cast<AffineMapAttr>()
|
|
|
|
|
.getValue());
|
|
|
|
|
maps.back() = simplifyAffineMap(
|
|
|
|
|
maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0));
|
|
|
|
|
maps.push_back(mlir::parseAttribute(
|
|
|
|
|
"affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, "
|
|
|
|
|
"s4, s5, s6, s7, s8, s9, s10] -> (d0, d1, d2, d3)>",
|
|
|
|
|
context)
|
|
|
|
|
.cast<AffineMapAttr>()
|
|
|
|
|
.getValue());
|
|
|
|
|
maps.back() = simplifyAffineMap(
|
|
|
|
|
maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0));
|
|
|
|
|
cached = Builder(context).getAffineMapArrayAttr(maps);
|
|
|
|
|
getOperation()->setAttr(memoizeAttr, cached);
|
|
|
|
|
return cached;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
unsigned FhelinalgConv2DNchwFchwOp::getNumRegionArgs() { return 3; }
|
|
|
|
|
|
|
|
|
|
std::string FhelinalgConv2DNchwFchwOp::getLibraryCallName() {
|
|
|
|
|
return generateLibraryCallName(getOperation());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FhelinalgConv2DNchwFchwOp::hasDynamicIndexingMaps() { return true; }
|
|
|
|
|
LogicalResult FhelinalgConv2DNchwFchwOp::verifyIndexingMapRequiredAttributes() {
|
|
|
|
|
Operation *op = getOperation();
|
|
|
|
|
|
|
|
|
|
if (auto attr = op->getAttrOfType<DenseElementsAttr>("strides")) {
|
|
|
|
|
if (!attr.getType().getElementType().isInteger(64))
|
|
|
|
|
return op->emitError("incorrect element type for indexing map required "
|
|
|
|
|
"attribute 'strides'");
|
|
|
|
|
if (attr.getType().getShape() != ArrayRef<int64_t>{2})
|
|
|
|
|
return op->emitError(
|
|
|
|
|
"incorrect shape for indexing map required attribute 'strides'");
|
|
|
|
|
} else {
|
|
|
|
|
return op->emitError("missing indexing map required attribute 'strides'");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (auto attr = op->getAttrOfType<DenseElementsAttr>("dilations")) {
|
|
|
|
|
if (!attr.getType().getElementType().isInteger(64))
|
|
|
|
|
return op->emitError("incorrect element type for indexing map required "
|
|
|
|
|
"attribute 'dilations'");
|
|
|
|
|
if (attr.getType().getShape() != ArrayRef<int64_t>{2})
|
|
|
|
|
return op->emitError(
|
|
|
|
|
"incorrect shape for indexing map required attribute 'dilations'");
|
|
|
|
|
} else {
|
|
|
|
|
return op->emitError("missing indexing map required attribute 'dilations'");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
|
|
|
|
|
// LLVM Exceptions
|
|
|
|
|
using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
|
|
|
|
|
ArrayRef<NamedAttribute>)>;
|
|
|
|
|
|
|
|
|
|
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
|
|
|
|
|
// LLVM Exceptions
|
|
|
|
|
static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
|
|
|
|
TypeRange resultTypes) {
|
|
|
|
|
if (resultTypes.empty())
|
|
|
|
|
return;
|
|
|
|
|
p.printOptionalArrowTypeList(resultTypes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
|
|
|
|
|
// LLVM Exceptions
|
|
|
|
|
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
|
|
|
|
|
ValueRange outputs) {
|
|
|
|
|
if (!inputs.empty())
|
|
|
|
|
p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
|
|
|
|
|
if (!outputs.empty())
|
|
|
|
|
p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
|
|
|
|
|
// LLVM Exceptions
|
|
|
|
|
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
|
|
|
|
|
ValueRange inputs, ValueRange outputs) {
|
|
|
|
|
p.printOptionalAttrDict(
|
|
|
|
|
op->getAttrs(),
|
|
|
|
|
/*elidedAttrs=*/{"operand_segment_sizes",
|
|
|
|
|
// See generated code in mlir-linalg-yaml-gen.cpp
|
|
|
|
|
"linalg.memoized_indexing_maps"});
|
|
|
|
|
|
|
|
|
|
// Printing is shared with generic ops, except for the region and
|
|
|
|
|
// attributes.
|
|
|
|
|
printCommonStructuredOpParts(p, inputs, outputs);
|
|
|
|
|
|
|
|
|
|
// Results printing.
|
|
|
|
|
printNamedStructuredOpResults(p, op->getResultTypes());
|
|
|
|
|
|
|
|
|
|
// Region is elided.
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FhelinalgConv2DNchwFchwOp::print(mlir::OpAsmPrinter &p) {
|
|
|
|
|
printNamedStructuredOp(p, this->getOperation(),
|
|
|
|
|
this->getOperation()->getOperands(),
|
|
|
|
|
this->getOperation()->getResults());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
|
|
|
|
|
// LLVM Exceptions
|
|
|
|
|
static ParseResult
|
|
|
|
|
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
|
|
|
|
SmallVectorImpl<Type> &inputTypes,
|
|
|
|
|
SmallVectorImpl<Type> &outputTypes) {
|
|
|
|
|
SMLoc inputsOperandsLoc, outputsOperandsLoc;
|
|
|
|
|
SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
|
|
|
|
|
outputsOperands;
|
|
|
|
|
|
|
|
|
|
parser.parseOptionalAttrDict(result.attributes);
|
|
|
|
|
|
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("ins"))) {
|
|
|
|
|
if (parser.parseLParen())
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
inputsOperandsLoc = parser.getCurrentLocation();
|
|
|
|
|
if (parser.parseOperandList(inputsOperands) ||
|
|
|
|
|
parser.parseColonTypeList(inputTypes) || parser.parseRParen())
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("outs"))) {
|
|
|
|
|
outputsOperandsLoc = parser.getCurrentLocation();
|
|
|
|
|
if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
|
|
|
|
|
parser.parseColonTypeList(outputTypes) || parser.parseRParen())
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
|
|
|
|
|
result.operands) ||
|
|
|
|
|
parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
|
|
|
|
|
result.operands))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
result.addAttribute("operand_segment_sizes",
|
|
|
|
|
parser.getBuilder().getI32VectorAttr(
|
|
|
|
|
{static_cast<int32_t>(inputsOperands.size()),
|
|
|
|
|
static_cast<int32_t>(outputsOperands.size())}));
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
|
|
|
|
|
// LLVM Exceptions
|
|
|
|
|
static ParseResult
|
|
|
|
|
parseNamedStructuredOpResults(OpAsmParser &parser,
|
|
|
|
|
SmallVectorImpl<Type> &resultTypes) {
|
|
|
|
|
if (parser.parseOptionalArrowTypeList(resultTypes))
|
|
|
|
|
return failure();
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
|
|
|
|
|
// LLVM Exceptions
|
|
|
|
|
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
|
|
|
|
|
TypeRange inputTypes, TypeRange outputTypes,
|
|
|
|
|
ArrayRef<NamedAttribute> attrs,
|
|
|
|
|
RegionBuilderFn regionBuilder) {
|
|
|
|
|
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
|
|
|
|
|
|
|
|
|
|
// TODO: atm all operands go through getElementTypeOrSelf,
|
|
|
|
|
// reconsider when we have evidence we need to.
|
|
|
|
|
SmallVector<Type, 8> argTypes;
|
|
|
|
|
SmallVector<Location, 8> argLocs;
|
|
|
|
|
for (auto containers : {inputTypes, outputTypes}) {
|
|
|
|
|
for (auto t : containers) {
|
|
|
|
|
argTypes.push_back(getElementTypeOrSelf(t));
|
|
|
|
|
|
|
|
|
|
// TODO: Pass in a proper location here.
|
|
|
|
|
argLocs.push_back(opBuilder.getUnknownLoc());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// RAII.
|
|
|
|
|
OpBuilder::InsertionGuard guard(opBuilder);
|
|
|
|
|
Block *body =
|
|
|
|
|
opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs);
|
|
|
|
|
|
|
|
|
|
opBuilder.setInsertionPointToStart(body);
|
|
|
|
|
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
|
|
|
|
|
regionBuilder(b, *body, attrs);
|
|
|
|
|
|
|
|
|
|
// indexing_maps is an auto-generated method.
|
|
|
|
|
|
|
|
|
|
// iterator_types is an auto-generated method.
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
|
|
|
|
|
// LLVM Exceptions
|
|
|
|
|
static ParseResult parseNamedStructuredOpRegion(
|
|
|
|
|
OpAsmParser &parser, Region ®ion, unsigned numRegionArgs,
|
|
|
|
|
TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
|
|
|
|
|
RegionBuilderFn regionBuilder) {
|
|
|
|
|
if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
|
|
|
|
|
return parser.emitError(
|
|
|
|
|
parser.getCurrentLocation(),
|
|
|
|
|
llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
|
|
|
|
|
"region expects {0} args, got {1}",
|
|
|
|
|
numRegionArgs, inputTypes.size() + outputTypes.size()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpBuilder opBuilder(parser.getContext());
|
|
|
|
|
fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
|
|
|
|
|
regionBuilder);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
|
|
|
|
|
// LLVM Exceptions
|
|
|
|
|
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
|
|
|
|
OperationState &result,
|
|
|
|
|
unsigned numRegionArgs,
|
|
|
|
|
RegionBuilderFn regionBuilder) {
|
|
|
|
|
// TODO: Enable when ods-gen supports captures.
|
|
|
|
|
SmallVector<Type, 1> inputTypes, outputTypes;
|
|
|
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
// TODO: consider merging results parsing into region parsing.
|
|
|
|
|
// Need to wait for declarative assembly resolution to decide.
|
|
|
|
|
SmallVector<Type, 1> outputTensorsTypes;
|
|
|
|
|
if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
|
|
|
|
|
return failure();
|
|
|
|
|
result.addTypes(outputTensorsTypes);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
|
|
|
|
if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
|
|
|
|
|
outputTypes, result.attributes.getAttrs(),
|
|
|
|
|
regionBuilder))
|
|
|
|
|
return failure();
|
|
|
|
|
result.addRegion(std::move(region));
|
|
|
|
|
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mlir::ParseResult
|
|
|
|
|
FhelinalgConv2DNchwFchwOp::parse(mlir::OpAsmParser &parser,
|
|
|
|
|
mlir::OperationState &result) {
|
|
|
|
|
return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
|
|
|
|
|
getRegionBuilder());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Some helpers were copied from LinalgOps.cpp
|
|
|
|
|
|
|
|
|
|
/// Generic entry point to create the block for the region of a LinalgOp.
|
|
|
|
|
/// This is used by both named structured ops created by ods-gen and by manually
|
|
|
|
|
/// defined C++ ops.
|
|
|
|
|
/// This is used by both builders and parsers.
|
|
|
|
|
/// This function creates the block in the region with arguments corresponding
|
|
|
|
|
/// to the elemental types of `inputTypes` and `outputTypes`. The latter are
|
|
|
|
|
/// asserted to be of ShapedType.
|
|
|
|
|
template <typename NamedStructuredOpType>
|
|
|
|
|
static void fillStructuredOpRegion(
|
|
|
|
|
OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes,
|
|
|
|
|
TypeRange outputTypes,
|
|
|
|
|
std::function<void(unsigned, unsigned)> errorHandler = nullptr);
|
|
|
|
|
|
|
|
|
|
/// Generic entry point to create both the region and the block of a LinalgOp.
|
|
|
|
|
template <typename NamedStructuredOpType>
|
|
|
|
|
static void
|
|
|
|
|
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
|
|
|
|
|
TypeRange inputTypes, TypeRange outputTypes);
|
|
|
|
|
|
|
|
|
|
/// Common parsing and printing used for both named structured ops created by
|
|
|
|
|
/// ods-gen and by manually defined C++ ops. Does not handle regions.
|
|
|
|
|
static ParseResult
|
|
|
|
|
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
|
|
|
|
SmallVectorImpl<Type> &inputTypes,
|
|
|
|
|
SmallVectorImpl<Type> &outputTypes);
|
|
|
|
|
template <typename NamedStructuredOpType>
|
|
|
|
|
static void printCommonStructuredOpParts(OpAsmPrinter &p,
|
|
|
|
|
NamedStructuredOpType op);
|
|
|
|
|
|
|
|
|
|
/// Specific parsing and printing for named structured ops created by ods-gen.
|
|
|
|
|
template <typename NamedStructuredOpType>
|
|
|
|
|
static ParseResult
|
|
|
|
|
parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
|
|
|
|
|
TypeRange inputTypes, TypeRange outputTypes);
|
|
|
|
|
|
|
|
|
|
static ParseResult
|
|
|
|
|
parseNamedStructuredOpResults(OpAsmParser &parser,
|
|
|
|
|
SmallVectorImpl<Type> &resultTypes);
|
|
|
|
|
|
|
|
|
|
template <typename NamedStructuredOpType>
|
|
|
|
|
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
|
|
|
|
OperationState &result);
|
|
|
|
|
|
|
|
|
|
static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
|
|
|
|
TypeRange resultTypes);
|
|
|
|
|
|
|
|
|
|
template <typename NamedStructuredOpType>
|
|
|
|
|
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
|
|
|
|
|
|
|
|
|
|
class RegionBuilderHelper {
|
|
|
|
|
public:
|
|
|
|
|
RegionBuilderHelper(MLIRContext *context, Block &block)
|
|
|
|
|
: context(context), block(block) {}
|
|
|
|
|
|
|
|
|
|
// Generates operations to cast the given operand to a specified type.
|
|
|
|
|
// If the cast cannot be performed, a warning will be issued and the
|
|
|
|
|
// operand returned as-is (which will presumably yield a verification
|
|
|
|
|
// issue downstream).
|
|
|
|
|
Value cast(Type toType, Value operand, bool isUnsignedCast) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
auto loc = operand.getLoc();
|
|
|
|
|
|
|
|
|
|
if (operand.getType() == toType)
|
|
|
|
|
return operand;
|
|
|
|
|
if (auto toIntType = toType.dyn_cast<IntegerType>()) {
|
|
|
|
|
// If operand is floating point, cast directly to the int type.
|
|
|
|
|
if (operand.getType().isa<FloatType>()) {
|
|
|
|
|
if (isUnsignedCast)
|
|
|
|
|
return builder.create<arith::FPToUIOp>(loc, toType, operand);
|
|
|
|
|
return builder.create<arith::FPToSIOp>(loc, toType, operand);
|
|
|
|
|
}
|
|
|
|
|
// Cast index operands directly to the int type.
|
|
|
|
|
if (operand.getType().isIndex())
|
|
|
|
|
return builder.create<arith::IndexCastOp>(loc, toType, operand);
|
|
|
|
|
if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
|
|
|
|
|
// Either extend or truncate.
|
|
|
|
|
if (toIntType.getWidth() > fromIntType.getWidth()) {
|
|
|
|
|
if (isUnsignedCast)
|
|
|
|
|
return builder.create<arith::ExtUIOp>(loc, toType, operand);
|
|
|
|
|
return builder.create<arith::ExtSIOp>(loc, toType, operand);
|
|
|
|
|
}
|
|
|
|
|
if (toIntType.getWidth() < fromIntType.getWidth())
|
|
|
|
|
return builder.create<arith::TruncIOp>(loc, toType, operand);
|
|
|
|
|
}
|
|
|
|
|
} else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
|
|
|
|
|
// If operand is integer, cast directly to the float type.
|
|
|
|
|
// Note that it is unclear how to cast from BF16<->FP16.
|
|
|
|
|
if (operand.getType().isa<IntegerType>()) {
|
|
|
|
|
if (isUnsignedCast)
|
|
|
|
|
return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
|
|
|
|
|
return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
|
|
|
|
|
}
|
|
|
|
|
if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
|
|
|
|
|
if (toFloatType.getWidth() > fromFloatType.getWidth())
|
|
|
|
|
return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
|
|
|
|
|
if (toFloatType.getWidth() < fromFloatType.getWidth())
|
|
|
|
|
return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
emitWarning(operand.getLoc()) << "could not cast operand of type "
|
|
|
|
|
<< operand.getType() << " to " << toType;
|
|
|
|
|
return operand;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value applyfn__add(Value lhs, Value rhs) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
if (isFloatingPoint(lhs))
|
|
|
|
|
return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
if (isInteger(lhs) &&
|
|
|
|
|
rhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
|
|
|
|
return builder.create<mlir::concretelang::FHE::AddEintIntOp>(lhs.getLoc(),
|
|
|
|
|
rhs, lhs);
|
|
|
|
|
}
|
|
|
|
|
if (lhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() &&
|
|
|
|
|
isInteger(rhs)) {
|
|
|
|
|
return builder.create<mlir::concretelang::FHE::AddEintIntOp>(lhs.getLoc(),
|
|
|
|
|
lhs, rhs);
|
|
|
|
|
}
|
|
|
|
|
if (lhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() &&
|
|
|
|
|
rhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
|
|
|
|
return builder.create<mlir::concretelang::FHE::AddEintOp>(lhs.getLoc(),
|
|
|
|
|
lhs, rhs);
|
|
|
|
|
}
|
|
|
|
|
llvm_unreachable("unsupported non numeric type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value applyfn__exp(Value x) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
if (isFloatingPoint(x))
|
|
|
|
|
return builder.create<math::ExpOp>(x.getLoc(), x);
|
|
|
|
|
llvm_unreachable("unsupported non numeric type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value applyfn__log(Value x) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
if (isFloatingPoint(x))
|
|
|
|
|
return builder.create<math::LogOp>(x.getLoc(), x);
|
|
|
|
|
llvm_unreachable("unsupported non numeric type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value applyfn__sub(Value lhs, Value rhs) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
if (isFloatingPoint(lhs))
|
|
|
|
|
return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
if (isInteger(lhs))
|
|
|
|
|
return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
llvm_unreachable("unsupported non numeric type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value applyfn__mul(Value lhs, Value rhs) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
if (isFloatingPoint(lhs))
|
|
|
|
|
return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
if (isInteger(lhs))
|
|
|
|
|
return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
if (lhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() &&
|
|
|
|
|
isInteger(rhs)) {
|
|
|
|
|
return builder.create<mlir::concretelang::FHE::MulEintIntOp>(lhs.getLoc(),
|
|
|
|
|
lhs, rhs);
|
|
|
|
|
}
|
|
|
|
|
llvm_unreachable("unsupported non numeric type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value applyfn__max(Value lhs, Value rhs) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
if (isFloatingPoint(lhs))
|
|
|
|
|
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
if (isInteger(lhs))
|
|
|
|
|
return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
llvm_unreachable("unsupported non numeric type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value applyfn__max_unsigned(Value lhs, Value rhs) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
if (isFloatingPoint(lhs))
|
|
|
|
|
return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
if (isInteger(lhs))
|
|
|
|
|
return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
llvm_unreachable("unsupported non numeric type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value applyfn__min(Value lhs, Value rhs) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
if (isFloatingPoint(lhs))
|
|
|
|
|
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
if (isInteger(lhs))
|
|
|
|
|
return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
llvm_unreachable("unsupported non numeric type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value applyfn__min_unsigned(Value lhs, Value rhs) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
if (isFloatingPoint(lhs))
|
|
|
|
|
return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
if (isInteger(lhs))
|
|
|
|
|
return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
|
|
|
|
|
llvm_unreachable("unsupported non numeric type");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void yieldOutputs(ValueRange values) {
|
|
|
|
|
assert(!values.empty() && "linalg ops must yield outputs");
|
|
|
|
|
if (values.empty())
|
|
|
|
|
return;
|
|
|
|
|
Value first = values.front();
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
builder.create<YieldOp>(first.getLoc(), values);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value constant(std::string value) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
Location loc = builder.getUnknownLoc();
|
|
|
|
|
Attribute valueAttr = parseAttribute(value, builder.getContext());
|
|
|
|
|
return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
|
|
|
|
|
valueAttr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value index(int64_t dim) {
|
|
|
|
|
OpBuilder builder = getBuilder();
|
|
|
|
|
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type getIntegerType(unsigned width) {
|
|
|
|
|
return IntegerType::get(context, width);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Type getFloat32Type() { return Float32Type::get(context); }
|
|
|
|
|
|
|
|
|
|
Type getFloat64Type() { return Float64Type::get(context); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
MLIRContext *context;
|
|
|
|
|
Block █
|
|
|
|
|
|
|
|
|
|
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
|
|
|
|
|
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
|
|
|
|
|
|
|
|
|
|
OpBuilder getBuilder() {
|
|
|
|
|
OpBuilder builder(context);
|
|
|
|
|
builder.setInsertionPointToEnd(&block);
|
|
|
|
|
return builder;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static LogicalResult foldMemRefCast(Operation *op) {
|
|
|
|
|
bool folded = false;
|
|
|
|
|
for (OpOperand &operand : op->getOpOperands()) {
|
|
|
|
|
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
|
|
|
|
|
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
|
|
|
|
|
operand.set(castOp.getOperand());
|
|
|
|
|
folded = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return success(folded);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Generic entry point to create the block for the region of a LinalgOp.
|
|
|
|
|
/// This is used by both named structured ops created by ods-gen and by manually
|
|
|
|
|
/// defined C++ ops.
|
|
|
|
|
/// This is used by both builders and parsers.
|
|
|
|
|
/// This function creates the block in the region with arguments corresponding
|
|
|
|
|
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
|
|
|
|
|
/// to be ShapedType.
|
|
|
|
|
template <typename NamedStructuredOpType>
|
|
|
|
|
static void
|
|
|
|
|
fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
|
|
|
|
|
TypeRange inputTypes, TypeRange outputTypes,
|
|
|
|
|
std::function<void(unsigned, unsigned)> errorHandler) {
|
|
|
|
|
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
|
|
|
|
|
|
|
|
|
|
// TODO: atm all operands go through getElementTypeOrSelf,
|
|
|
|
|
// reconsider when we have evidence we need to.
|
|
|
|
|
SmallVector<Type, 8> argTypes;
|
|
|
|
|
SmallVector<Location, 8> argLocs;
|
|
|
|
|
for (auto containers : {inputTypes, outputTypes})
|
|
|
|
|
for (auto t : containers) {
|
|
|
|
|
argTypes.push_back(getElementTypeOrSelf(t));
|
|
|
|
|
argLocs.push_back(opBuilder.getUnknownLoc());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// RAII.
|
|
|
|
|
OpBuilder::InsertionGuard guard(opBuilder);
|
|
|
|
|
Block *body =
|
|
|
|
|
opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs);
|
|
|
|
|
unsigned actual = body->getNumArguments();
|
|
|
|
|
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
|
|
|
|
|
if (expected != actual) {
|
|
|
|
|
if (errorHandler)
|
|
|
|
|
errorHandler(expected, actual);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
opBuilder.setInsertionPointToStart(body);
|
|
|
|
|
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
|
|
|
|
|
NamedStructuredOpType::regionBuilder(b, *body, {});
|
|
|
|
|
|
|
|
|
|
// indexing_maps is an auto-generated method.
|
|
|
|
|
|
|
|
|
|
// iterator_types is an auto-generated method.
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void getGenericEffectsImpl(
|
|
|
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
|
|
|
&effects,
|
|
|
|
|
ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
|
|
|
|
|
for (Value value : results) {
|
|
|
|
|
effects.emplace_back(MemoryEffects::Allocate::get(), value,
|
|
|
|
|
SideEffects::DefaultResource::get());
|
|
|
|
|
}
|
|
|
|
|
for (Value value : inputBuffers) {
|
|
|
|
|
effects.emplace_back(MemoryEffects::Read::get(), value,
|
|
|
|
|
SideEffects::DefaultResource::get());
|
|
|
|
|
}
|
|
|
|
|
for (Value value : outputs) {
|
|
|
|
|
effects.emplace_back(MemoryEffects::Read::get(), value,
|
|
|
|
|
SideEffects::DefaultResource::get());
|
|
|
|
|
effects.emplace_back(MemoryEffects::Write::get(), value,
|
|
|
|
|
SideEffects::DefaultResource::get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Generic entry point to create both the region and the block of a LinalgOp.
|
|
|
|
|
template <typename NamedStructuredOpType>
|
|
|
|
|
void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
|
|
|
|
|
OperationState &result,
|
|
|
|
|
TypeRange inputTypes,
|
|
|
|
|
TypeRange outputTypes) {
|
|
|
|
|
Region ®ion = *result.addRegion();
|
|
|
|
|
fillStructuredOpRegion<NamedStructuredOpType>(
|
|
|
|
|
opBuilder, region, inputTypes, outputTypes,
|
|
|
|
|
[&](unsigned expected, unsigned actual) {
|
|
|
|
|
assert(expected != actual && "incorrect number of arguments");
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// END OF COPY FROM LinalgOps.cpp
|
|
|
|
|
|
|
|
|
|
void FhelinalgConv2DNchwFchwOp::regionBuilder(
|
|
|
|
|
ImplicitLocOpBuilder &b, Block &block,
|
|
|
|
|
llvm::ArrayRef<mlir::NamedAttribute>) {
|
|
|
|
|
assert(3 > 0 && block.getNumArguments() == 3 &&
|
|
|
|
|
"FhelinalgConv2DNchwFchwOp regionBuilder expects 3 (>=0) args");
|
|
|
|
|
RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
|
|
|
|
|
SmallVector<Value> yields;
|
|
|
|
|
Value value1 =
|
|
|
|
|
helper.cast(block.getArgument(0).getType(), block.getArgument(0), false);
|
|
|
|
|
Value value2 =
|
|
|
|
|
helper.cast(block.getArgument(1).getType(), block.getArgument(1), false);
|
|
|
|
|
Value value3 = helper.applyfn__mul(value1, value2);
|
|
|
|
|
Value value4 = helper.applyfn__add(block.getArgument(2), value3);
|
|
|
|
|
yields.push_back(value4);
|
|
|
|
|
helper.yieldOutputs(yields);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LogicalResult FhelinalgConv2DNchwFchwOp::fold(ArrayRef<Attribute>,
|
|
|
|
|
SmallVectorImpl<OpFoldResult> &) {
|
|
|
|
|
return foldMemRefCast(*this);
|
|
|
|
|
}
|
|
|
|
|
void FhelinalgConv2DNchwFchwOp::getEffects(
|
|
|
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
|
|
|
&effects) {
|
|
|
|
|
SmallVector<Value> inputBuffers = getInputBufferOperands();
|
|
|
|
|
SmallVector<Value> outputBuffers = getOutputBufferOperands();
|
|
|
|
|
getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
|
|
|
|
|
outputBuffers);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Verify the transpose shapes
|
|
|
|
|
mlir::LogicalResult TransposeOp::verify() {
|
|
|
|
|
mlir::Type tensorTy = ((mlir::Type)this->tensor().getType());
|
|
|
|
|
|