diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td index ef5e871cf..d3ac0ef3e 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td @@ -30,21 +30,21 @@ class TFHE_BatchableBinaryOp< { let extraClassDeclaration = [{ struct BatchingVariant { - static const unsigned BATCHED_SCALAR = 0; - static const unsigned SCALAR_BATCHED = 1; - static const unsigned ALL_BATCHED = 2; + static const unsigned ALL_BATCHED = 0; + static const unsigned BATCHED_SCALAR = 1; + static const unsigned SCALAR_BATCHED = 2; }; unsigned getNumBatchingVariants() { return 3; } ::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands(unsigned variant) { switch(variant) { + case BatchingVariant::ALL_BATCHED: + return getOperation()->getOpOperands(); case BatchingVariant::BATCHED_SCALAR: return getOperation()->getOpOperands().take_front(); case BatchingVariant::SCALAR_BATCHED: return getOperation()->getOpOperands().drop_front().take_front(); - case BatchingVariant::ALL_BATCHED: - return getOperation()->getOpOperands(); } llvm_unreachable("Unknown batching variant"); @@ -61,6 +61,12 @@ class TFHE_BatchableBinaryOp< ::llvm::SmallVector<::mlir::Value> operands; switch(variant) { + case BatchingVariant::ALL_BATCHED: + operands = batchedOperands; + return builder.create<}] # all_batched_opname # [{>( + mlir::TypeRange{resType}, + operands, + getOperation()->getAttrs()); case BatchingVariant::BATCHED_SCALAR: operands.push_back(batchedOperands[0]); operands.push_back(hoistedNonBatchableOperands[0]); @@ -75,12 +81,6 @@ class TFHE_BatchableBinaryOp< mlir::TypeRange{resType}, operands, getOperation()->getAttrs()); - case BatchingVariant::ALL_BATCHED: - operands = batchedOperands; - return builder.create<}] # all_batched_opname # [{>( - mlir::TypeRange{resType}, - operands, - getOperation()->getAttrs()); } llvm_unreachable("Unknown batching variant");