refactor: remove code related to our custom conv2D named op

It was introduced as a workaround while linalg couldn't support types
other than int/float.
This commit is contained in:
youben11
2022-06-23 14:18:59 +01:00
committed by Ayoub Benaissa
parent f1f1db923d
commit 63d84a3e4a
6 changed files with 1 additions and 873 deletions

View File

@@ -240,10 +240,6 @@ release_tarballs:
update_python_version:
echo "__version__ = \"`git describe --tags --abbrev=0 | grep -e '[0-9].*' -o`\"" > lib/Bindings/Python/version.txt
generate_conv_op:
python -m mlir.dialects.linalg.opdsl.dump_oplib ops.core_named_ops > ops/LinalgNamedStructuredOps.yaml
$(BUILD_DIR)/bin/mlir-linalg-ods-yaml-gen ops/LinalgNamedStructuredOps.yaml --o-impl=ops/LinalgOps.cpp --o-ods-decl=ops/LinalgNamedStructuredOps.yamlgen.td
check_python_format:
black --check tests/python/ lib/Bindings/Python/concrete/
@@ -266,7 +262,6 @@ python_lint:
package_py310 \
release_tarballs \
update_python_version \
generate_conv_op \
python_lint \
python_format \
check_python_format \

View File

@@ -97,10 +97,6 @@ public:
} // namespace mlir
#define GET_OP_CLASSES
// TODO: remove this when removing the custom linalg op for Conv
// the generated code was calling functions from the mlir::linalg namespace
using namespace mlir::linalg;
// END TODO
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc"
#endif

View File

@@ -3,9 +3,6 @@
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td"
include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.td"
@@ -946,151 +943,6 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
let hasVerifier = 1;
}
// Local copy of the Linalg dialect's structured-op base class so that the
// custom conv op below can implement the structured-op interfaces while
// living outside the upstream Linalg ODS files.
class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
    : Op<Linalg_Dialect, mnemonic, !listconcat([
          SingleBlockImplicitTerminator<"YieldOp">,
          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
          LinalgStructuredInterface,
          ReifyRankedShapedTypeOpInterface], props)> {
  // Shared C++ declarations spliced into every derived op's class body.
  code structuredOpsBaseDecls = [{
    // Return whether the op accesses the iteration indices.
    bool hasIndexSemantics() {
      return !this->getBody()->getOps<IndexOp>().empty();
    }

    LogicalResult reifyResultShapes(OpBuilder &b,
                                    ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
      return cast<LinalgOp>(getOperation()).reifyResultShapes(b,
                                                              reifiedReturnShapes);
    }
  }];
}
// Custom clone of the upstream linalg.conv_2d_nchw_fchw named op, declared
// in the FHELinalg cppNamespace so its payload region can carry FHE
// arithmetic (see RegionBuilderHelper in FHELinalgOps.cpp).
def FhelinalgConv2DNchwFchwOp : LinalgStructuredBase_Op<"fhelinalg_conv_2d_nchw_fchw", !listconcat([AttrSizedOperandSegments],
    /*extraInterfaces=*/[LinalgConvolutionOpInterface])> {
  let cppNamespace = "mlir::concretelang::FHELinalg";
  let summary = [{ Performs 2-D convolution. }];
  let description = [{
    Layout:
      * Input: NCHW.
      * Kernel: FCHW.
    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
  }];
  let arguments = (ins
    Variadic<AnyType>:$inputs,
    Variadic<AnyShaped>:$outputs,
    // Both attrs are rank-1 i64 dense elements of shape [2]: [H, W].
    RankedI64ElementsAttr<[2]>:$strides,
    RankedI64ElementsAttr<[2]>:$dilations
  );
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);
  let skipDefaultBuilders = 1;
  let builders = [
    // Build from inputs/outputs; result types are inferred from the
    // ranked-tensor outputs and the region is filled by the region builder.
    OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(inputs);
        $_state.addOperands(outputs);
        SmallVector<Type> resultTensorTypes;
        copy_if(outputs.getTypes(),
                std::back_inserter(resultTensorTypes),
                [](Type type) { return type.isa<RankedTensorType>(); });
        $_state.addTypes(resultTensorTypes);
        $_state.addAttribute(
          "operand_segment_sizes",
          $_builder.getI32VectorAttr({
            static_cast<int32_t>(inputs.size()),
            static_cast<int32_t>(outputs.size())}));
        $_state.addAttributes(attributes);
        createAndFillStructuredOpRegion<FhelinalgConv2DNchwFchwOp>(
          $_builder,
          $_state,
          TypeRange(inputs),
          TypeRange(outputs));
      }]>,
    // Build with explicit result types.
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
        "ValueRange":$outputs,
        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(inputs);
        $_state.addOperands(outputs);
        $_state.addTypes(resultTensorTypes);
        $_state.addAttributes(attributes);
        $_state.addAttribute(
          "operand_segment_sizes",
          $_builder.getI32VectorAttr({
            static_cast<int32_t>(inputs.size()),
            static_cast<int32_t>(outputs.size())}));
        createAndFillStructuredOpRegion<FhelinalgConv2DNchwFchwOp>(
          $_builder,
          $_state,
          TypeRange(inputs),
          TypeRange(outputs));
      }]>,
    // Build from a flat operand list; the region is left empty and must be
    // populated by the caller.
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(operands);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion();
      }]>
    // Build with explicit strides/dilations attributes.
    , OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
        "ValueRange":$outputs, "Attribute":$strides, "Attribute":$dilations,
        CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(inputs);
        $_state.addOperands(outputs);
        $_state.addTypes(resultTensorTypes);
        $_state.addAttribute(
          "operand_segment_sizes",
          $_builder.getI32VectorAttr({
            static_cast<int32_t>(inputs.size()),
            static_cast<int32_t>(outputs.size())}));
        createAndFillStructuredOpRegion<FhelinalgConv2DNchwFchwOp>(
          $_builder,
          $_state,
          TypeRange(inputs),
          TypeRange(outputs));
        $_state.addAttribute("strides", strides);
        $_state.addAttribute("dilations", dilations);
        $_state.addAttributes(attributes);
      }]>
  ];
  let hasCustomAssemblyFormat = 1;
  let hasFolder = 1;
  // Methods implemented in FHELinalgOps.cpp (generated + hand-copied).
  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Auto-generated.
    ArrayAttr iterator_types();
    ArrayAttr indexing_maps();
    static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, llvm::ArrayRef<mlir::NamedAttribute>);
    static std::function<void(ImplicitLocOpBuilder &b, Block &, llvm::ArrayRef<mlir::NamedAttribute>)>
    getRegionBuilder() {
      return regionBuilder;
    }

    // Generic methods.
    static unsigned getNumRegionArgs();
    std::string getLibraryCallName();
    bool hasDynamicIndexingMaps();
    LogicalResult verifyIndexingMapRequiredAttributes();
  }];
}
def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> {
let summary = "Returns a tensor that contains the transposition of the input tensor.";

View File

@@ -6,13 +6,9 @@
#include <unordered_set>
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Parser/Parser.h"
#include "llvm/Support/FormatVariadic.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
@@ -1019,688 +1015,6 @@ mlir::LogicalResult FromElementOp::verify() {
return mlir::success();
}
//===----------------------------------------------------------------------===//
// Implementation of FhelinalgConv2DNchwFchwOp
// This is a generated functions from `make generate_conv_op`, and some helpers
// from LinalgOps.cpp
//===----------------------------------------------------------------------===//
using namespace mlir;
using namespace mlir::linalg;
/// Iterator kinds for the 7-d convolution loop nest: four parallel
/// dimensions (n, f, oh, ow) followed by three reduction dimensions
/// (c, kh, kw).
ArrayAttr FhelinalgConv2DNchwFchwOp::iterator_types() {
  SmallVector<StringRef> kinds;
  for (unsigned i = 0; i < 4; ++i)
    kinds.push_back(getParallelIteratorTypeName());
  for (unsigned i = 0; i < 3; ++i)
    kinds.push_back(getReductionIteratorTypeName());
  return Builder(getContext()).getStrArrayAttr(kinds);
}
// Binds the 11 affine symbols used by the indexing maps below. Symbols at
// positions 3/5/7/9 are replaced by the constant stride/dilation values
// read from the op's attributes; the remaining symbols stay free.
static SmallVector<AffineExpr>
getSymbolBindings(FhelinalgConv2DNchwFchwOp self) {
  MLIRContext *context = self.getContext();
  SmallVector<AffineExpr> exprs;
  exprs.push_back(getAffineSymbolExpr(0, context));
  exprs.push_back(getAffineSymbolExpr(1, context));
  exprs.push_back(getAffineSymbolExpr(2, context));
  // s3: stride along H.
  int64_t cst3 = self.strides().getValues<int64_t>()[{0}];
  exprs.push_back(getAffineConstantExpr(cst3, context));
  exprs.push_back(getAffineSymbolExpr(4, context));
  // s5: dilation along H.
  int64_t cst5 = self.dilations().getValues<int64_t>()[{0}];
  exprs.push_back(getAffineConstantExpr(cst5, context));
  exprs.push_back(getAffineSymbolExpr(6, context));
  // s7: stride along W.
  int64_t cst7 = self.strides().getValues<int64_t>()[{1}];
  exprs.push_back(getAffineConstantExpr(cst7, context));
  exprs.push_back(getAffineSymbolExpr(8, context));
  // s9: dilation along W.
  int64_t cst9 = self.dilations().getValues<int64_t>()[{1}];
  exprs.push_back(getAffineConstantExpr(cst9, context));
  exprs.push_back(getAffineSymbolExpr(10, context));
  return exprs;
}
// Builds (and memoizes on the op, under "linalg.memoized_indexing_maps")
// the three indexing maps of the convolution: input, kernel and output
// accesses. Stride/dilation symbols are substituted with their constant
// attribute values via getSymbolBindings, then the maps are simplified.
ArrayAttr FhelinalgConv2DNchwFchwOp::indexing_maps() {
  static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
  ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
  if (cached)
    return cached;
  MLIRContext *context = getContext();
  auto symbolBindings = getSymbolBindings(*this);
  SmallVector<AffineMap> maps;
  // Input access: (n, c, oh * sh + kh * dh, ow * sw + kw * dw).
  maps.push_back(
      mlir::parseAttribute(
          "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, "
          "s7, s8, s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>",
          context)
          .cast<AffineMapAttr>()
          .getValue());
  maps.back() = simplifyAffineMap(
      maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0));
  // Kernel access: (f, c, kh, kw).
  maps.push_back(mlir::parseAttribute(
                     "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, "
                     "s4, s5, s6, s7, s8, s9, s10] -> (d1, d4, d5, d6)>",
                     context)
                     .cast<AffineMapAttr>()
                     .getValue());
  maps.back() = simplifyAffineMap(
      maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0));
  // Output access: (n, f, oh, ow).
  maps.push_back(mlir::parseAttribute(
                     "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, "
                     "s4, s5, s6, s7, s8, s9, s10] -> (d0, d1, d2, d3)>",
                     context)
                     .cast<AffineMapAttr>()
                     .getValue());
  maps.back() = simplifyAffineMap(
      maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0));
  cached = Builder(context).getAffineMapArrayAttr(maps);
  getOperation()->setAttr(memoizeAttr, cached);
  return cached;
}
// Number of block arguments in the payload region: input element, kernel
// element, and output accumulator.
unsigned FhelinalgConv2DNchwFchwOp::getNumRegionArgs() { return 3; }
// Library call name used when lowering to external library calls.
std::string FhelinalgConv2DNchwFchwOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
// The indexing maps depend on the strides/dilations attributes, hence
// "dynamic" (triggers verifyIndexingMapRequiredAttributes below).
bool FhelinalgConv2DNchwFchwOp::hasDynamicIndexingMaps() { return true; }
/// Verifies the attributes the dynamic indexing maps depend on: both
/// `strides` and `dilations` must be present as dense i64 elements
/// attributes of shape [2].
///
/// Refactor over the original: the two attributes share the same contract,
/// so the duplicated verification is factored into one helper; emitted
/// error messages are unchanged.
LogicalResult FhelinalgConv2DNchwFchwOp::verifyIndexingMapRequiredAttributes() {
  Operation *op = getOperation();
  auto verifyShape2I64Attr = [&](StringRef name) -> LogicalResult {
    auto attr = op->getAttrOfType<DenseElementsAttr>(name);
    if (!attr)
      return op->emitError("missing indexing map required attribute '")
             << name << "'";
    if (!attr.getType().getElementType().isInteger(64))
      return op->emitError("incorrect element type for indexing map required "
                           "attribute '")
             << name << "'";
    if (attr.getType().getShape() != ArrayRef<int64_t>{2})
      return op->emitError(
                 "incorrect shape for indexing map required attribute '")
             << name << "'";
    return success();
  };
  if (failed(verifyShape2I64Attr("strides")) ||
      failed(verifyShape2I64Attr("dilations")))
    return failure();
  return success();
}
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
// LLVM Exceptions
using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
ArrayRef<NamedAttribute>)>;
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
// LLVM Exceptions
/// Prints the optional `-> (types...)` result list; prints nothing when
/// there are no results.
static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (!resultTypes.empty())
    p.printOptionalArrowTypeList(resultTypes);
}
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
// LLVM Exceptions
/// Prints the `ins(...)`/`outs(...)` operand groups shared by all named
/// structured ops; empty groups are elided entirely.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  auto printGroup = [&](StringRef keyword, ValueRange values) {
    if (values.empty())
      return;
    p << " " << keyword << "(" << values << " : " << values.getTypes() << ")";
  };
  printGroup("ins", inputs);
  printGroup("outs", outputs);
}
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
// LLVM Exceptions
// Prints a named structured op: attribute dictionary (minus internal
// attributes), then the ins/outs operand groups and result types. The
// region is elided because the parser regenerates it from the region
// builder.
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operand_segment_sizes",
                       // See generated code in mlir-linalg-yaml-gen.cpp
                       "linalg.memoized_indexing_maps"});
  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);
  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());
  // Region is elided.
}
/// Custom assembly printer entry point (hasCustomAssemblyFormat = 1).
void FhelinalgConv2DNchwFchwOp::print(mlir::OpAsmPrinter &p) {
  Operation *op = this->getOperation();
  printNamedStructuredOp(p, op, op->getOperands(), op->getResults());
}
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
// LLVM Exceptions
/// Parses the `ins(...)`/`outs(...)` operand groups shared by all named
/// structured ops and records the operand segment sizes.
///
/// Fix over the copied original: the result of parseOptionalAttrDict was
/// silently discarded, so a malformed attribute dictionary did not abort
/// parsing. It now propagates failure (matching upstream LinalgOps.cpp).
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes) {
  SMLoc inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;
  // The optional attribute dictionary comes first; a malformed one is an
  // error.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();
    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();
  // Segment sizes let AttrSizedOperandSegments split the flat operand list
  // back into inputs and outputs.
  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast<int32_t>(inputsOperands.size()),
                           static_cast<int32_t>(outputsOperands.size())}));
  return success();
}
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
// LLVM Exceptions
/// Parses the optional `-> (types...)` result list; the underlying parser
/// call already yields the ParseResult we need.
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  return parser.parseOptionalArrowTypeList(resultTypes);
}
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
// LLVM Exceptions
// Creates the single block of `region` with one scalar argument per
// input/output element type, then delegates to `regionBuilder` to emit the
// payload ops.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
  // TODO: atm all operands go through getElementTypeOrSelf,
  // reconsider when we have evidence we need to.
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(getElementTypeOrSelf(t));
      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }
  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);
  // indexing_maps is an auto-generated method.
  // iterator_types is an auto-generated method.
}
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
// LLVM Exceptions
// Validates the expected region argument count against the parsed operand
// types, then materializes the region via the op's region builder (the
// region is never parsed textually).
static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }
  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}
// Copied from LinalgOps.cpp; license is: Apache License v2.0 with
// LLVM Exceptions
// Top-level parser for named structured ops: operands, result types, then
// the regenerated payload region.
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();
  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));
  return success();
}
// Custom assembly parser entry point (hasCustomAssemblyFormat = 1).
mlir::ParseResult
FhelinalgConv2DNchwFchwOp::parse(mlir::OpAsmParser &parser,
                                 mlir::OperationState &result) {
  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
                                getRegionBuilder());
}
/// Some helpers were copied from LinalgOps.cpp
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`. The latter are
/// asserted to be of ShapedType.
// NOTE(review): these forward declarations were copied alongside the
// helpers above, but several of the templated print/parse variants have no
// matching definition in this file and appear to be unused leftovers —
// candidates for removal; verify no other TU relies on them before doing so.
template <typename NamedStructuredOpType>
static void fillStructuredOpRegion(
    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
    TypeRange outputTypes,
    std::function<void(unsigned, unsigned)> errorHandler = nullptr);
/// Generic entry point to create both the region and the block of a LinalgOp.
template <typename NamedStructuredOpType>
static void
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
                                TypeRange inputTypes, TypeRange outputTypes);
/// Common parsing and printing used for both named structured ops created by
/// ods-gen and by manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes);
template <typename NamedStructuredOpType>
static void printCommonStructuredOpParts(OpAsmPrinter &p,
                                         NamedStructuredOpType op);
/// Specific parsing and printing for named structured ops created by ods-gen.
template <typename NamedStructuredOpType>
static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
                             TypeRange inputTypes, TypeRange outputTypes);
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes);
template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result);
static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes);
template <typename NamedStructuredOpType>
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
// Copied/adapted helper that emits the scalar payload ops for the region
// builder. Unlike the upstream LinalgOps.cpp version, the applyfn__*
// methods additionally understand FHE encrypted-integer operands.
class RegionBuilderHelper {
public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();
    if (operand.getType() == toType)
      return operand;
    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
      // If operand is floating point, cast directly to the int type.
      if (operand.getType().isa<FloatType>()) {
        if (isUnsignedCast)
          return builder.create<arith::FPToUIOp>(loc, toType, operand);
        return builder.create<arith::FPToSIOp>(loc, toType, operand);
      }
      // Cast index operands directly to the int type.
      if (operand.getType().isIndex())
        return builder.create<arith::IndexCastOp>(loc, toType, operand);
      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
        // Either extend or truncate.
        if (toIntType.getWidth() > fromIntType.getWidth()) {
          if (isUnsignedCast)
            return builder.create<arith::ExtUIOp>(loc, toType, operand);
          return builder.create<arith::ExtSIOp>(loc, toType, operand);
        }
        if (toIntType.getWidth() < fromIntType.getWidth())
          return builder.create<arith::TruncIOp>(loc, toType, operand);
      }
    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
      // If operand is integer, cast directly to the float type.
      // Note that it is unclear how to cast from BF16<->FP16.
      if (operand.getType().isa<IntegerType>()) {
        if (isUnsignedCast)
          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
      }
      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
        if (toFloatType.getWidth() > fromFloatType.getWidth())
          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
        if (toFloatType.getWidth() < fromFloatType.getWidth())
          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
      }
    }
    emitWarning(operand.getLoc()) << "could not cast operand of type "
                                  << operand.getType() << " to " << toType;
    return operand;
  }

  // Addition: supports float, clear+encrypted (either operand order, the
  // clear/encrypted operands are swapped so the FHE op sees (eint, int)),
  // and encrypted+encrypted.
  Value applyfn__add(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs) &&
        rhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
      return builder.create<mlir::concretelang::FHE::AddEintIntOp>(lhs.getLoc(),
                                                                   rhs, lhs);
    }
    if (lhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() &&
        isInteger(rhs)) {
      return builder.create<mlir::concretelang::FHE::AddEintIntOp>(lhs.getLoc(),
                                                                   lhs, rhs);
    }
    if (lhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() &&
        rhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
      return builder.create<mlir::concretelang::FHE::AddEintOp>(lhs.getLoc(),
                                                                lhs, rhs);
    }
    llvm_unreachable("unsupported non numeric type");
  }

  Value applyfn__exp(Value x) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(x))
      return builder.create<math::ExpOp>(x.getLoc(), x);
    llvm_unreachable("unsupported non numeric type");
  }

  Value applyfn__log(Value x) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(x))
      return builder.create<math::LogOp>(x.getLoc(), x);
    llvm_unreachable("unsupported non numeric type");
  }

  Value applyfn__sub(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // Multiplication: additionally supports encrypted * clear integer.
  Value applyfn__mul(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
    if (lhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() &&
        isInteger(rhs)) {
      return builder.create<mlir::concretelang::FHE::MulEintIntOp>(lhs.getLoc(),
                                                                   lhs, rhs);
    }
    llvm_unreachable("unsupported non numeric type");
  }

  Value applyfn__max(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  Value applyfn__max_unsigned(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  Value applyfn__min(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  Value applyfn__min_unsigned(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // Terminates the block with a linalg.yield of the computed outputs.
  void yieldOutputs(ValueRange values) {
    assert(!values.empty() && "linalg ops must yield outputs");
    if (values.empty())
      return;
    Value first = values.front();
    OpBuilder builder = getBuilder();
    builder.create<YieldOp>(first.getLoc(), values);
  }

  // Materializes a constant parsed from its textual attribute form.
  Value constant(std::string value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
                                             valueAttr);
  }

  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context, width);
  }

  Type getFloat32Type() { return Float32Type::get(context); }
  Type getFloat64Type() { return Float64Type::get(context); }

private:
  MLIRContext *context;
  Block &block;

  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

  // A fresh builder positioned at the end of the block for each emitted op.
  OpBuilder getBuilder() {
    OpBuilder builder(context);
    builder.setInsertionPointToEnd(&block);
    return builder;
  }
};
/// Replaces any operand produced by a foldable memref.cast with the cast's
/// source; reports success only if at least one operand was rewritten.
static LogicalResult foldMemRefCast(Operation *op) {
  bool anyFolded = false;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto castOp = opOperand.get().getDefiningOp<memref::CastOp>();
    if (!castOp || !memref::CastOp::canFoldIntoConsumerOp(castOp))
      continue;
    opOperand.set(castOp.getOperand());
    anyFolded = true;
  }
  return success(anyFolded);
}
/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
/// to be ShapedType.
// Templated variant used by the ODS-generated builders; validates the
// block argument count against NamedStructuredOpType::getNumRegionArgs()
// and reports a mismatch through `errorHandler` instead of asserting.
template <typename NamedStructuredOpType>
static void
fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                       TypeRange inputTypes, TypeRange outputTypes,
                       std::function<void(unsigned, unsigned)> errorHandler) {
  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
  // TODO: atm all operands go through getElementTypeOrSelf,
  // reconsider when we have evidence we need to.
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes})
    for (auto t : containers) {
      argTypes.push_back(getElementTypeOrSelf(t));
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
  unsigned actual = body->getNumArguments();
  unsigned expected = NamedStructuredOpType::getNumRegionArgs();
  if (expected != actual) {
    if (errorHandler)
      errorHandler(expected, actual);
    return;
  }
  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  NamedStructuredOpType::regionBuilder(b, *body, {});
  // indexing_maps is an auto-generated method.
  // iterator_types is an auto-generated method.
}
/// Registers the standard structured-op memory effects: results are
/// allocated, input buffers are read, output buffers are read and written.
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
  auto record = [&](MemoryEffects::Effect *effect, Value value) {
    effects.emplace_back(effect, value, SideEffects::DefaultResource::get());
  };
  for (Value value : results)
    record(MemoryEffects::Allocate::get(), value);
  for (Value value : inputBuffers)
    record(MemoryEffects::Read::get(), value);
  for (Value value : outputs) {
    record(MemoryEffects::Read::get(), value);
    record(MemoryEffects::Write::get(), value);
  }
}
/// Generic entry point to create both the region and the block of a LinalgOp.
/// Creates both the region and the payload block of a structured op.
///
/// Fix over the copied original: the error handler is only invoked when
/// `expected != actual`, so the previous `assert(expected != actual && ...)`
/// was always satisfied and the intended debug diagnostic could never fire.
/// The assert now states the invariant that was meant to hold, so it trips
/// on a region-argument-count mismatch in debug builds.
template <typename NamedStructuredOpType>
void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
                                     OperationState &result,
                                     TypeRange inputTypes,
                                     TypeRange outputTypes) {
  Region &region = *result.addRegion();
  fillStructuredOpRegion<NamedStructuredOpType>(
      opBuilder, region, inputTypes, outputTypes,
      [&](unsigned expected, unsigned actual) {
        assert(expected == actual && "incorrect number of arguments");
      });
}
/// END OF COPY FROM LinalgOps.cpp
/// Builds the scalar payload of the convolution region:
///   out += cast(input) * cast(kernel)
/// using the FHE-aware arithmetic helpers in RegionBuilderHelper.
///
/// Fix over the generated original: drop the vacuously true `3 > 0 &&`
/// conjunct from the assert (generated-code artifact with no effect).
void FhelinalgConv2DNchwFchwOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block,
    llvm::ArrayRef<mlir::NamedAttribute>) {
  assert(block.getNumArguments() == 3 &&
         "FhelinalgConv2DNchwFchwOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
  SmallVector<Value> yields;
  // The casts target the arguments' own types (generated structure kept
  // from the linalg named-op region builders).
  Value value1 =
      helper.cast(block.getArgument(0).getType(), block.getArgument(0), false);
  Value value2 =
      helper.cast(block.getArgument(1).getType(), block.getArgument(1), false);
  Value value3 = helper.applyfn__mul(value1, value2);
  Value value4 = helper.applyfn__add(block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
/// Folding only canonicalizes away memref.cast ops feeding the operands.
LogicalResult FhelinalgConv2DNchwFchwOp::fold(ArrayRef<Attribute>,
                                              SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(getOperation());
}
// Reports memory effects via getGenericEffectsImpl: reads of input
// buffers, reads+writes of output buffers, allocation of results.
void FhelinalgConv2DNchwFchwOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> inputBuffers = getInputBufferOperands();
  SmallVector<Value> outputBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
                        outputBuffers);
}
/// Verify the transpose shapes
mlir::LogicalResult TransposeOp::verify() {
mlir::Type tensorTy = ((mlir::Type)this->tensor().getType());

View File

@@ -50,7 +50,7 @@ static bool isCandidateForTask(Operation *op) {
FHELinalg::ApplyLookupTableEintOp, FHELinalg::ApplyMultiLookupTableEintOp,
FHELinalg::ApplyMappedLookupTableEintOp, FHELinalg::Dot,
FHELinalg::MatMulEintIntOp, FHELinalg::MatMulIntEintOp, FHELinalg::SumOp,
FHELinalg::ConcatOp, FHELinalg::FhelinalgConv2DNchwFchwOp>(op);
FHELinalg::ConcatOp>(op);
}
/// Identify operations that are beneficial to sink into tasks. These

View File

@@ -1,29 +0,0 @@
from mlir.dialects.linalg.opdsl.lang import *
T1 = TV.T1
T2 = TV.T2
Batch = S.Batch
# Declarative opdsl spec for the custom FHE conv named op. `make
# generate_conv_op` dumps it to YAML and generates the ODS/C++ sources.
# NOTE(review): `U` is presumably the output type variable from the opdsl
# lang star-import — confirm against mlir.dialects.linalg.opdsl.lang.
@linalg_structured_op
def fhelinalg_conv_2d_nchw_fchw(
    I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
    strides=AttributeDef(S.SH, S.SW),
    dilations=AttributeDef(S.DH, S.DW)):
    """Performs 2-D convolution.
    Layout:
      * Input: NCHW.
      * Kernel: FCHW.
    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
    """
    implements(ConvolutionOpInterface)
    # Iteration domain: (n, f, oh, ow) parallel, (c, kh, kw) reduction.
    domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
    O[D.n, D.f, D.oh, D.ow] += cast(
        U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW
        ]) * cast(U, K[D.f, D.c, D.kh, D.kw])