mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat: add generated linalg conv operation
This code was generated with the linalg tooling and then placed in the appropriate locations by hand. It is intended as a workaround, since linalg does not yet support tensors of custom types. Once linalg gains that support, any conversion that uses this added operation should be able to switch to the default linalg operation instead.
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
|
||||
@@ -96,6 +97,10 @@ public:
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
// TODO: remove this when removing the custom linalg op for Conv
|
||||
// the generated code was calling functions from the mlir::linalg namespace
|
||||
using namespace mlir::linalg;
|
||||
// END TODO
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc"
|
||||
|
||||
#endif
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
|
||||
include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td"
|
||||
include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.td"
|
||||
@@ -628,4 +631,153 @@ def ConcatOp : FHELinalg_Op<"concat"> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Local mirror of upstream linalg's structured-op base class.
// It attaches the traits/interfaces that hook an op into the structured-ops
// machinery: a single-block region terminated by an implicit `YieldOp`,
// memory-effect reporting, `LinalgStructuredInterface`, and ranked-shape
// reification. Extra traits can be appended through `props`.
class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
    : Op<Linalg_Dialect, mnemonic, !listconcat([
          SingleBlockImplicitTerminator<"YieldOp">,
          DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
          LinalgStructuredInterface,
          ReifyRankedShapedTypeOpInterface], props)> {
  // Shared C++ declarations spliced into every derived op's class body.
  code structuredOpsBaseDecls = [{
    // Return whether the op accesses the iteration indices.
    bool hasIndexSemantics() {
      return !this->getBody()->getOps<IndexOp>().empty();
    }

    // Delegate result-shape reification to the LinalgOp interface
    // implementation.
    LogicalResult reifyResultShapes(OpBuilder &b,
        ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
      return cast<LinalgOp>(getOperation())
          .reifyResultShapes(b, reifiedReturnShapes);
    }
  }];
}
|
||||
|
||||
// Hand-placed copy of the generated linalg `conv_2d_nchw_fchw` op, adapted
// for the FHELinalg dialect (workaround while linalg does not support
// tensors of custom types). Inputs are unconstrained (`AnyType`) so
// FHE-typed tensors are accepted.
def FhelinalgConv2DNchwFchwOp
    : LinalgStructuredBase_Op<"fhelinalg_conv_2d_nchw_fchw",
          !listconcat([AttrSizedOperandSegments],
                      /*extraInterfaces=*/[LinalgConvolutionOpInterface])> {

  let cppNamespace = "mlir::concretelang::FHELinalg";
  let summary = [{ Performs 2-D convolution. }];
  let description = [{
    Layout:
      * Input: NCHW.
      * Kernel: FCHW.

    Numeric casting is performed on the operands to the inner multiply, promoting
    them to the same data type as the accumulator/output.
  }];

  let arguments = (ins
    // Operands (image, kernel); kept fully generic to allow FHE element types.
    Variadic<AnyType>:$inputs,
    // Output/init operands (shaped: tensors or memrefs).
    Variadic<AnyShaped>:$outputs,
    // 2-element i64 vectors controlling the sliding window.
    RankedI64ElementsAttr<[2]>:$strides,
    RankedI64ElementsAttr<[2]>:$dilations
  );
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);

  // Builders are written out by hand (generated style); the default ones
  // would not populate the region or the segment-size attribute.
  let skipDefaultBuilders = 1;
  let builders = [
    // Build from inputs/outputs; result types are inferred from the
    // ranked-tensor outputs.
    OpBuilder<
      (ins "ValueRange":$inputs, "ValueRange":$outputs,
           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(inputs);
        $_state.addOperands(outputs);
        SmallVector<Type> resultTensorTypes;
        copy_if(outputs.getTypes(),
                std::back_inserter(resultTensorTypes),
                [](Type type) { return type.isa<RankedTensorType>(); });
        $_state.addTypes(resultTensorTypes);
        $_state.addAttribute(
            "operand_segment_sizes",
            $_builder.getI32VectorAttr({
                static_cast<int32_t>(inputs.size()),
                static_cast<int32_t>(outputs.size())}));
        $_state.addAttributes(attributes);
        createAndFillStructuredOpRegion<FhelinalgConv2DNchwFchwOp>(
            $_builder,
            $_state,
            TypeRange(inputs),
            TypeRange(outputs));
      }]>,
    // Build with explicitly provided result types.
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
           "ValueRange":$outputs,
           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(inputs);
        $_state.addOperands(outputs);
        $_state.addTypes(resultTensorTypes);
        $_state.addAttributes(attributes);
        $_state.addAttribute(
            "operand_segment_sizes",
            $_builder.getI32VectorAttr({
                static_cast<int32_t>(inputs.size()),
                static_cast<int32_t>(outputs.size())}));
        createAndFillStructuredOpRegion<FhelinalgConv2DNchwFchwOp>(
            $_builder,
            $_state,
            TypeRange(inputs),
            TypeRange(outputs));
      }]>,
    // Raw build from a flat operand list; leaves the region empty for the
    // caller to fill.
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(operands);
        $_state.addAttributes(attributes);
        $_state.addTypes(resultTensorTypes);
        (void)$_state.addRegion();
      }]>,
    // Build with explicit strides/dilations attributes.
    OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
           "ValueRange":$outputs, "Attribute":$strides, "Attribute":$dilations,
           CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
      [{
        $_state.addOperands(inputs);
        $_state.addOperands(outputs);
        $_state.addTypes(resultTensorTypes);
        $_state.addAttribute(
            "operand_segment_sizes",
            $_builder.getI32VectorAttr({
                static_cast<int32_t>(inputs.size()),
                static_cast<int32_t>(outputs.size())}));
        createAndFillStructuredOpRegion<FhelinalgConv2DNchwFchwOp>(
            $_builder,
            $_state,
            TypeRange(inputs),
            TypeRange(outputs));
        $_state.addAttribute("strides", strides);
        $_state.addAttribute("dilations", dilations);
        $_state.addAttributes(attributes);
      }]>
  ];

  // Custom print/parse hooks live in the FHELinalg namespace (the upstream
  // linalg helpers are not usable for this dialect's ops).
  let printer = [{ return mlir::concretelang::FHELinalg::printNamedStructuredOp(p, *this); }];
  let parser = [{
    return mlir::concretelang::FHELinalg::parseNamedStructuredOp<FhelinalgConv2DNchwFchwOp>(parser, result);
  }];
  let hasFolder = 1;

  let extraClassDeclaration = structuredOpsBaseDecls # [{
    // Auto-generated.
    ArrayAttr iterator_types();
    ArrayAttr indexing_maps();
    static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
    static std::function<void(ImplicitLocOpBuilder &b, Block &)>
    getRegionBuilder() {
      return regionBuilder;
    }

    // Generic methods.
    static unsigned getNumRegionArgs();
    std::string getLibraryCallName();

    bool hasDynamicIndexingMaps();
    LogicalResult verifyIndexingMapRequiredAttributes();
  }];
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user