mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor: remove code related to our custom conv2D named op
it was introduced as a workaround while linalng couldn't support other types than int/float
This commit is contained in:
@@ -97,10 +97,6 @@ public:
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
// TODO: remove this when removing the custom linalg op for Conv
|
||||
// the generated code was calling functions from the mlir::linalg namespace
|
||||
using namespace mlir::linalg;
|
||||
// END TODO
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h.inc"
|
||||
|
||||
#endif
|
||||
|
||||
@@ -3,9 +3,6 @@
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
|
||||
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
|
||||
include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.td"
|
||||
include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.td"
|
||||
@@ -946,151 +943,6 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
class LinalgStructuredBase_Op<string mnemonic, list<Trait> props>
|
||||
: Op<Linalg_Dialect, mnemonic, !listconcat([
|
||||
SingleBlockImplicitTerminator<"YieldOp">,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
LinalgStructuredInterface,
|
||||
ReifyRankedShapedTypeOpInterface], props)> {
|
||||
code structuredOpsBaseDecls = [{
|
||||
// Return whether the op accesses the iteration indices.
|
||||
bool hasIndexSemantics() {
|
||||
return !this->getBody()->getOps<IndexOp>().empty();
|
||||
}
|
||||
|
||||
LogicalResult reifyResultShapes(OpBuilder &b,
|
||||
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
|
||||
return cast<LinalgOp>(getOperation()).reifyResultShapes(b,
|
||||
reifiedReturnShapes);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def FhelinalgConv2DNchwFchwOp : LinalgStructuredBase_Op<"fhelinalg_conv_2d_nchw_fchw", !listconcat([AttrSizedOperandSegments],
|
||||
/*extraInterfaces=*/[LinalgConvolutionOpInterface])> {
|
||||
|
||||
let cppNamespace = "mlir::concretelang::FHELinalg";
|
||||
let summary = [{ Performs 2-D convolution. }];
|
||||
let description = [{
|
||||
Layout:
|
||||
* Input: NCHW.
|
||||
* Kernel: FCHW.
|
||||
|
||||
Numeric casting is performed on the operands to the inner multiply, promoting
|
||||
them to the same data type as the accumulator/output.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$inputs,
|
||||
Variadic<AnyShaped>:$outputs,
|
||||
RankedI64ElementsAttr<[2]>:$strides,
|
||||
RankedI64ElementsAttr<[2]>:$dilations
|
||||
);
|
||||
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
|
||||
let regions = (region AnyRegion:$region);
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilder<
|
||||
(ins "ValueRange":$inputs, "ValueRange":$outputs,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
[{
|
||||
$_state.addOperands(inputs);
|
||||
$_state.addOperands(outputs);
|
||||
SmallVector<Type> resultTensorTypes;
|
||||
copy_if(outputs.getTypes(),
|
||||
std::back_inserter(resultTensorTypes),
|
||||
[](Type type) { return type.isa<RankedTensorType>(); });
|
||||
$_state.addTypes(resultTensorTypes);
|
||||
$_state.addAttribute(
|
||||
"operand_segment_sizes",
|
||||
$_builder.getI32VectorAttr({
|
||||
static_cast<int32_t>(inputs.size()),
|
||||
static_cast<int32_t>(outputs.size())}));
|
||||
$_state.addAttributes(attributes);
|
||||
createAndFillStructuredOpRegion<FhelinalgConv2DNchwFchwOp>(
|
||||
$_builder,
|
||||
$_state,
|
||||
TypeRange(inputs),
|
||||
TypeRange(outputs));
|
||||
}]>,
|
||||
OpBuilder<
|
||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||
"ValueRange":$outputs,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
[{
|
||||
$_state.addOperands(inputs);
|
||||
$_state.addOperands(outputs);
|
||||
$_state.addTypes(resultTensorTypes);
|
||||
$_state.addAttributes(attributes);
|
||||
$_state.addAttribute(
|
||||
"operand_segment_sizes",
|
||||
$_builder.getI32VectorAttr({
|
||||
static_cast<int32_t>(inputs.size()),
|
||||
static_cast<int32_t>(outputs.size())}));
|
||||
createAndFillStructuredOpRegion<FhelinalgConv2DNchwFchwOp>(
|
||||
$_builder,
|
||||
$_state,
|
||||
TypeRange(inputs),
|
||||
TypeRange(outputs));
|
||||
}]>,
|
||||
OpBuilder<
|
||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
[{
|
||||
$_state.addOperands(operands);
|
||||
$_state.addAttributes(attributes);
|
||||
$_state.addTypes(resultTensorTypes);
|
||||
(void)$_state.addRegion();
|
||||
}]>
|
||||
|
||||
, OpBuilder<
|
||||
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
||||
"ValueRange":$outputs, "Attribute":$strides, "Attribute":$dilations,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||
[{
|
||||
$_state.addOperands(inputs);
|
||||
$_state.addOperands(outputs);
|
||||
$_state.addTypes(resultTensorTypes);
|
||||
$_state.addAttribute(
|
||||
"operand_segment_sizes",
|
||||
$_builder.getI32VectorAttr({
|
||||
static_cast<int32_t>(inputs.size()),
|
||||
static_cast<int32_t>(outputs.size())}));
|
||||
createAndFillStructuredOpRegion<FhelinalgConv2DNchwFchwOp>(
|
||||
$_builder,
|
||||
$_state,
|
||||
TypeRange(inputs),
|
||||
TypeRange(outputs));
|
||||
$_state.addAttribute("strides", strides);
|
||||
$_state.addAttribute("dilations", dilations);
|
||||
$_state.addAttributes(attributes);
|
||||
}]>
|
||||
|
||||
];
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = structuredOpsBaseDecls # [{
|
||||
// Auto-generated.
|
||||
ArrayAttr iterator_types();
|
||||
ArrayAttr indexing_maps();
|
||||
static void regionBuilder(ImplicitLocOpBuilder &b, Block &block, llvm::ArrayRef<mlir::NamedAttribute>);
|
||||
static std::function<void(ImplicitLocOpBuilder &b, Block &, llvm::ArrayRef<mlir::NamedAttribute>)>
|
||||
getRegionBuilder() {
|
||||
return regionBuilder;
|
||||
}
|
||||
|
||||
// Generic methods.
|
||||
static unsigned getNumRegionArgs();
|
||||
std::string getLibraryCallName();
|
||||
|
||||
bool hasDynamicIndexingMaps();
|
||||
LogicalResult verifyIndexingMapRequiredAttributes();
|
||||
|
||||
}];
|
||||
}
|
||||
|
||||
def FHELinalg_TransposeOp : FHELinalg_Op<"transpose", []> {
|
||||
let summary = "Returns a tensor that contains the transposition of the input tensor.";
|
||||
|
||||
|
||||
Reference in New Issue
Block a user