From 549d2ded86add82090c48857aa9c53e64618ba77 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Fri, 9 Jun 2023 16:57:30 +0200 Subject: [PATCH] feat(compiler): Batching: Favor batching of all operands for binary TFHE ops This changes the order of batching variants for binary TFHE operations, such that batching of both operands is favored over batching of a single operand. --- .../concretelang/Dialect/TFHE/IR/TFHEOps.td | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td index ef5e871cf..d3ac0ef3e 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td +++ b/compilers/concrete-compiler/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td @@ -30,21 +30,21 @@ class TFHE_BatchableBinaryOp< { let extraClassDeclaration = [{ struct BatchingVariant { - static const unsigned BATCHED_SCALAR = 0; - static const unsigned SCALAR_BATCHED = 1; - static const unsigned ALL_BATCHED = 2; + static const unsigned ALL_BATCHED = 0; + static const unsigned BATCHED_SCALAR = 1; + static const unsigned SCALAR_BATCHED = 2; }; unsigned getNumBatchingVariants() { return 3; } ::llvm::MutableArrayRef<::mlir::OpOperand> getBatchableOperands(unsigned variant) { switch(variant) { + case BatchingVariant::ALL_BATCHED: + return getOperation()->getOpOperands(); case BatchingVariant::BATCHED_SCALAR: return getOperation()->getOpOperands().take_front(); case BatchingVariant::SCALAR_BATCHED: return getOperation()->getOpOperands().drop_front().take_front(); - case BatchingVariant::ALL_BATCHED: - return getOperation()->getOpOperands(); } llvm_unreachable("Unknown batching variant"); @@ -61,6 +61,12 @@ class TFHE_BatchableBinaryOp< ::llvm::SmallVector<::mlir::Value> operands; switch(variant) { + case BatchingVariant::ALL_BATCHED: + operands = batchedOperands; + return builder.create<}] # all_batched_opname # [{>( + mlir::TypeRange{resType}, + operands, + getOperation()->getAttrs()); case BatchingVariant::BATCHED_SCALAR: operands.push_back(batchedOperands[0]); operands.push_back(hoistedNonBatchableOperands[0]); @@ -75,12 +81,6 @@ class TFHE_BatchableBinaryOp< mlir::TypeRange{resType}, operands, getOperation()->getAttrs()); - case BatchingVariant::ALL_BATCHED: - operands = batchedOperands; - return builder.create<}] # all_batched_opname # [{>( - mlir::TypeRange{resType}, - operands, - getOperation()->getAttrs()); } llvm_unreachable("Unknown batching variant");