// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. #include #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/TypeUtilities.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h" namespace mlir { namespace OpTrait { namespace impl { LogicalResult verifyTensorBroadcastingRules( mlir::Operation *op, llvm::SmallVector operands, mlir::RankedTensorType result) { llvm::SmallVector> operandsShapes; size_t maxOperandsDim = 0; auto resultShape = result.getShape(); for (size_t i = 0; i < operands.size(); i++) { auto shape = operands[i].getShape(); operandsShapes.push_back(shape); maxOperandsDim = std::max(shape.size(), maxOperandsDim); } // Check the result has the same number of dimension than the highest // dimension of operands if (resultShape.size() != maxOperandsDim) { op->emitOpError() << "should have the number of dimensions of the result equal to the " "highest number of dimensions of operands" << ", got " << result.getShape().size() << " expect " << maxOperandsDim; return mlir::failure(); } // For all dimension for (size_t i = 0; i < maxOperandsDim; i++) { int64_t expectedResultDim = 1; // Check the dimension of operands shape are compatible, i.e. equals or 1 for (size_t j = 0; j < operandsShapes.size(); j++) { if (i < maxOperandsDim - operandsShapes[j].size()) { continue; } auto k = i - (maxOperandsDim - operandsShapes[j].size()); auto operandDim = operandsShapes[j][k]; if (expectedResultDim != 1 && operandDim != 1 && operandDim != expectedResultDim) { op->emitOpError() << "has the dimension #" << (operandsShapes[j].size() - k) << " of the operand #" << j << " incompatible with other operands" << ", got " << operandDim << " expect 1 or " << expectedResultDim; return mlir::failure(); } expectedResultDim = std::max(operandDim, expectedResultDim); } // Check the dimension of the result is compatible with dimesion of the // operands if (resultShape[i] != expectedResultDim) { op->emitOpError() << "has the dimension #" << (maxOperandsDim - i) << " of the result incompatible with operands dimension" << ", got " << resultShape[i] << " expect " << expectedResultDim; return mlir::failure(); } } return mlir::success(); } LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op) { // Check operands type are ranked tensor llvm::SmallVector tensorOperands; unsigned i = 0; for (auto opType : op->getOperandTypes()) { auto tensorType = opType.dyn_cast_or_null(); if (tensorType == nullptr) { op->emitOpError() << " should have a ranked tensor as operand #" << i; return mlir::failure(); } tensorOperands.push_back(tensorType); i++; } // Check number of result is 1 if (op->getNumResults() != 1) { op->emitOpError() << "should have exactly 1 result, got " << op->getNumResults(); } auto tensorResult = op->getResult(0).getType().dyn_cast_or_null(); if (tensorResult == nullptr) { op->emitOpError(llvm::Twine("should have a ranked tensor as result")); return mlir::failure(); } return verifyTensorBroadcastingRules(op, tensorOperands, tensorResult); } LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) { if (op->getNumOperands() != 2) { op->emitOpError() << "should have exactly 2 operands"; return mlir::failure(); } auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); if (op0Ty == nullptr || op1Ty == nullptr) { op->emitOpError() << "should have both operands as tensor"; return mlir::failure(); } auto el0Ty = op0Ty.getElementType() .dyn_cast_or_null(); if (el0Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor of operand #0"; return mlir::failure(); } auto el1Ty = op1Ty.getElementType().dyn_cast_or_null(); if (el1Ty == nullptr) { op->emitOpError() << "should have an integer as the element type of the " "tensor of operand #1"; return mlir::failure(); } if (el1Ty.getWidth() > el0Ty.getWidth() + 1) { op->emitOpError() << "should have the width of integer values less or equals " "than the width of encrypted values + 1"; return mlir::failure(); } return mlir::success(); } LogicalResult verifyTensorBinaryIntEint(mlir::Operation *op) { if (op->getNumOperands() != 2) { op->emitOpError() << "should have exactly 2 operands"; return mlir::failure(); } auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); if (op0Ty == nullptr || op1Ty == nullptr) { op->emitOpError() << "should have both operands as tensor"; return mlir::failure(); } auto el0Ty = op0Ty.getElementType().dyn_cast_or_null(); if (el0Ty == nullptr) { op->emitOpError() << "should have an integer as the element type of the " "tensor of operand #0"; return mlir::failure(); } auto el1Ty = op1Ty.getElementType() .dyn_cast_or_null(); if (el1Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor of operand #1"; return mlir::failure(); } if (el1Ty.getWidth() > el0Ty.getWidth() + 1) { op->emitOpError() << "should have the width of integer values less or equals " "than the width of encrypted values + 1"; return mlir::failure(); } return mlir::success(); } LogicalResult verifyTensorBinaryEint(mlir::Operation *op) { if (op->getNumOperands() != 2) { op->emitOpError() << "should have exactly 2 operands"; return mlir::failure(); } auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); if (op0Ty == nullptr || op1Ty == nullptr) { op->emitOpError() << "should have both operands as tensor"; return mlir::failure(); } auto el0Ty = op0Ty.getElementType() .dyn_cast_or_null(); if (el0Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor of operand #0"; return mlir::failure(); } auto el1Ty = op1Ty.getElementType() .dyn_cast_or_null(); if (el1Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor of operand #1"; return mlir::failure(); } if (el1Ty.getWidth() != el0Ty.getWidth()) { op->emitOpError() << "should have the width of encrypted equals" ", got " << el1Ty.getWidth() << " expect " << el0Ty.getWidth(); return mlir::failure(); } return mlir::success(); } LogicalResult verifyTensorUnaryEint(mlir::Operation *op) { if (op->getNumOperands() != 1) { op->emitOpError() << "should have exactly 1 operands"; return mlir::failure(); } auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); if (op0Ty == nullptr) { op->emitOpError() << "should have operand as tensor"; return mlir::failure(); } auto el0Ty = op0Ty.getElementType() .dyn_cast_or_null(); if (el0Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor operand"; return mlir::failure(); } return mlir::success(); } } // namespace impl } // namespace OpTrait } // namespace mlir namespace mlir { namespace concretelang { namespace FHELinalg { mlir::LogicalResult ApplyLookupTableEintOp::verify() { auto tTy = this->t().getType().cast(); auto tEltTy = tTy.getElementType() .cast(); auto lutTy = this->lut().getType().cast(); auto lutEltTy = lutTy.getElementType().cast(); auto resultTy = this->getResult().getType().cast(); // Check the shape of lut argument auto tEltwidth = tEltTy.getWidth(); mlir::SmallVector expectedShape{1 << tEltwidth}; if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isInteger(64)) { this->emitOpError() << "should have as operand #2 a tensor<2^pxi64>, where p is the width " "of the encrypted integer of the operand #1," << "expect tensor <" << expectedShape[0] << "xi64>"; return mlir::failure(); } if (!resultTy.hasStaticShape(tTy.getShape())) { this->emitOpError() << " should have same shapes for operand #1 and the result"; } return mlir::success(); } mlir::LogicalResult ApplyMultiLookupTableEintOp::verify() { auto tTy = this->t().getType().cast(); auto tEltTy = tTy.getElementType() .cast(); auto lutTy = this->luts().getType().cast(); auto lutEltTy = lutTy.getElementType().cast(); auto resultTy = this->getResult().getType().cast(); // Check the shape of luts argument auto lut_size = lutTy.getShape()[lutTy.getShape().size() - 1]; auto expected_lut_size = 1 << tEltTy.getWidth(); if (lut_size != expected_lut_size || !lutEltTy.isInteger(64)) { this->emitOpError() << "should have as operand #2 a " "tensor, where p is the width " "of the encrypted integer of the operand #1," << "expect tensor "; return mlir::failure(); } if (!resultTy.hasStaticShape(tTy.getShape())) { this->emitOpError() << " should have same shapes for operand #1 and the result"; } return mlir::success(); } mlir::RankedTensorType getTensorType(::mlir::Value value) { return value.getType().cast(); } template T getElmentType(::mlir::Value value) { auto tTy = getTensorType(value); return tTy.getElementType().cast(); } mlir::IntegerType getClearElmentType(::mlir::Value value) { return getElmentType(value); } FHE::EncryptedIntegerType getEncryptedElmentType(::mlir::Value value) { using namespace mlir::concretelang::FHE; return getElmentType(value); } mlir::LogicalResult verifyMapHasRightShape(ApplyMappedLookupTableEintOp &op, ::mlir::Value &lut_input, ::mlir::Value &lut_map) { auto input_shape = getTensorType(lut_input).getShape(); auto map_shape = getTensorType(lut_map).getShape(); if (input_shape.equals(map_shape)) { return mlir::success(); } std::string error; int input_rank = input_shape.size(); int map_rank = map_shape.size(); std::string input_name = "'t' (operand #1)"; std::string map_name = "'lut_map.getName()' (operand #3)"; if (input_rank == map_rank) { error = ": " + input_name + " dimensions differs from " + map_name; } else { error = ": " + input_name + " rank (=" + std::to_string(input_rank) + ") differs from " + map_name + " rank (=" + std::to_string(map_rank) + ")"; } op.emitOpError() << error; return mlir::failure(); } mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op, ::mlir::Value &encryptedIndex, ::mlir::Value &luts) { auto index_width = getEncryptedElmentType(encryptedIndex).getWidth(); auto actual_lut_size = getTensorType(luts).getShape().back(); auto expected_lut_size = 1 << index_width; if (actual_lut_size == expected_lut_size) { return mlir::success(); } FHE::emitErrorBadLutSize(op, "luts", "ct", expected_lut_size, index_width); return mlir::failure(); } mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() { auto t = this->t(); auto luts = this->luts(); auto map = this->map(); auto result = this->getResult(); auto t_shape = getTensorType(t).getShape(); if (!getTensorType(result).hasStaticShape(t_shape)) { this->emitOpError() << ": `t` (operand #1) and `map` (operand #2) must have the same shape"; return mlir::failure(); } if (!getTensorType(map).getElementType().isIndex()) { this->emitOpError() << ": `map` (operand #3) should contains elements of type `index`"; return mlir::failure(); } return mlir::success(verifyMapHasRightShape(*this, t, map).succeeded() && verifyLutsSize(*this, t, luts).succeeded()); } ::mlir::LogicalResult Dot::verify() { if (::mlir::failed(mlir::verifyCompatibleShape(this->lhs().getType(), this->rhs().getType()))) { return this->emitOpError("arguments have incompatible shapes"); } auto lhsEltType = this->lhs() .getType() .cast() .getElementType() .cast(); auto rhsEltType = this->rhs() .getType() .cast() .getElementType() .cast(); auto resultType = this->getResult().getType().cast(); if (!mlir::concretelang::FHE:: verifyEncryptedIntegerAndIntegerInputsConsistency( *this->getOperation(), lhsEltType, rhsEltType)) { return ::mlir::failure(); } if (!FHE::verifyEncryptedIntegerInputAndResultConsistency( *this->getOperation(), lhsEltType, resultType)) { return ::mlir::failure(); } return ::mlir::success(); } llvm::SmallVector verifySumCalculateActualOutputShape(mlir::Type outputType) { auto actualOutputShape = llvm::SmallVector{}; if (outputType.isa()) { auto outputTensorType = outputType.dyn_cast(); for (int64_t size : outputTensorType.getShape()) { actualOutputShape.push_back(size); } } return actualOutputShape; } llvm::SmallVector verifySumCalculateExpectedOutputShape( llvm::ArrayRef inputShape, int64_t inputDimensions, std::unordered_set &axesToDestroy, bool keepDims) { auto expectedOutputShape = llvm::SmallVector{}; for (int64_t i = 0; i < inputDimensions; i++) { bool ithAxisIsDestroyed = axesToDestroy.find(i) != axesToDestroy.end(); if (!ithAxisIsDestroyed) { expectedOutputShape.push_back(inputShape[i]); } else if (keepDims) { expectedOutputShape.push_back(1); } } return expectedOutputShape; } mlir::LogicalResult SumOp::verify() { mlir::Value input = this->getOperand(); mlir::Value output = this->getResult(); auto inputType = input.getType().dyn_cast(); mlir::Type outputType = output.getType(); FHE::EncryptedIntegerType inputElementType = inputType.getElementType().dyn_cast(); FHE::EncryptedIntegerType outputElementType = !outputType.isa() ? outputType.dyn_cast() : outputType.dyn_cast() .getElementType() .dyn_cast(); if (!FHE::verifyEncryptedIntegerInputAndResultConsistency( *this->getOperation(), inputElementType, outputElementType)) { return mlir::failure(); } llvm::ArrayRef inputShape = inputType.getShape(); int64_t inputDimensions = (int64_t)inputShape.size(); mlir::ArrayAttr axes = this->axes(); bool keepDims = this->keep_dims(); auto axesToDestroy = std::unordered_set{}; for (mlir::Attribute axisAttribute : axes) { int64_t axis = axisAttribute.cast().getInt(); bool axisIsValid = (0 <= axis) && (axis < inputDimensions); if (!axisIsValid) { this->emitOpError("has invalid axes attribute"); return mlir::failure(); } axesToDestroy.insert(axis); } if (axesToDestroy.empty()) { for (int64_t i = 0; i < inputDimensions; i++) { axesToDestroy.insert(i); } } auto expectedOutputShape = verifySumCalculateExpectedOutputShape( inputShape, inputDimensions, axesToDestroy, keepDims); auto actualOutputShape = verifySumCalculateActualOutputShape(outputType); if (expectedOutputShape != actualOutputShape) { auto stream = this->emitOpError(); stream << "does not have the proper output shape of <"; if (!expectedOutputShape.empty()) { stream << expectedOutputShape[0]; for (size_t i = 1; i < expectedOutputShape.size(); i++) { stream << "x" << expectedOutputShape[i]; } } stream << ">"; return mlir::failure(); } return mlir::success(); } static bool sameShapeExceptAxis(llvm::ArrayRef shape1, llvm::ArrayRef shape2, size_t axis) { if (shape1.size() != shape2.size()) { return false; } for (size_t i = 0; i < shape1.size(); i++) { if (i != axis && shape1[i] != shape2[i]) { return false; } } return true; } mlir::LogicalResult ConcatOp::verify() { unsigned numOperands = this->getNumOperands(); if (numOperands < 2) { this->emitOpError() << "should have at least 2 inputs"; return mlir::failure(); } int64_t axis = this->axis(); mlir::Value out = this->out(); auto outVectorType = out.getType().dyn_cast(); auto outElementType = outVectorType.getElementType().dyn_cast(); llvm::ArrayRef outShape = outVectorType.getShape(); size_t outDims = outShape.size(); if (axis < 0 || (size_t)axis >= outDims) { this->emitOpError() << "has invalid axis attribute"; return mlir::failure(); } int64_t expectedOutputElementsInAxis = 0; size_t index = 0; for (mlir::Value in : this->ins()) { auto inVectorType = in.getType().dyn_cast(); auto inElementType = inVectorType.getElementType().dyn_cast(); if (!FHE::verifyEncryptedIntegerInputAndResultConsistency( *this->getOperation(), inElementType, outElementType)) { return ::mlir::failure(); } llvm::ArrayRef inShape = inVectorType.getShape(); if (!sameShapeExceptAxis(inShape, outShape, (size_t)axis)) { auto stream = this->emitOpError(); stream << "does not have the proper shape of <"; if (axis == 0) { stream << "?"; } else { stream << outShape[0]; } for (size_t i = 1; i < outDims; i++) { stream << "x"; if (i == (size_t)axis) { stream << "?"; } else { stream << outShape[i]; } } stream << "> for input #" << index; return mlir::failure(); } expectedOutputElementsInAxis += inShape[axis]; index += 1; } if (outShape[axis] != expectedOutputElementsInAxis) { auto stream = this->emitOpError(); stream << "does not have the proper output shape of <"; if (axis == 0) { stream << expectedOutputElementsInAxis; } else { stream << outShape[0]; } for (size_t i = 1; i < outDims; i++) { stream << "x"; if (i == (size_t)axis) { stream << expectedOutputElementsInAxis; } else { stream << outShape[i]; } } stream << ">"; return mlir::failure(); } return mlir::success(); } /// Verify the matmul shapes, the type of tensor elements should be checked by /// something else template mlir::LogicalResult verifyMatmul(MatMulOp &op) { auto lhsType = ((mlir::Type)op.lhs().getType()).cast(); auto rhsType = ((mlir::Type)op.rhs().getType()).cast(); llvm::ArrayRef lhsShape = lhsType.getShape(); llvm::ArrayRef rhsShape = rhsType.getShape(); int64_t lhsDims = (int64_t)lhsShape.size(); int64_t rhsDims = (int64_t)rhsShape.size(); auto expectedOutputShape = mlir::SmallVector{}; if (lhsDims == 2 && rhsDims == 2) { // MxN @ NxP -> MxP if (lhsShape[1] != rhsShape[0]) { op.emitOpError() << "should have the same size " "on dimension #1 of operand #0 " "and dimension #0 of operand #1"; return mlir::failure(); } expectedOutputShape.push_back(lhsShape[0]); expectedOutputShape.push_back(rhsShape[1]); } else if (lhsDims >= 2 && rhsDims >= 2) { // KxLxMxN @ NxP -> KxLxMxP // KxLxMxN @ LxNxP -> KxLxMxP // Kx1xMxN @ LxNxP -> KxLxMxP // MxN @ KxLxNxP -> KxLxMxP // LxMxN @ KxLxNxP -> KxLxMxP // 1xMxN @ KxLxNxP -> KxLxMxP if (lhsShape[lhsDims - 1] != rhsShape[rhsDims - 2]) { op.emitOpError() << "should have the same size " << "on dimension #" << lhsDims - 1 << " of operand #0 " << "and dimension #" << rhsDims - 2 << " of operand #1"; return mlir::failure(); } auto expectedOutputShapeReversed = mlir::SmallVector{}; expectedOutputShapeReversed.push_back(rhsShape[rhsDims - 1]); expectedOutputShapeReversed.push_back(lhsShape[lhsDims - 2]); int64_t i = lhsDims - 3; int64_t j = rhsDims - 3; while (i >= 0 && j >= 0) { int64_t lhsSize = lhsShape[i]; int64_t rhsSize = rhsShape[j]; if (lhsSize == rhsSize || lhsSize == 1 || rhsSize == 1) { expectedOutputShapeReversed.push_back(std::max(lhsSize, rhsSize)); } else { op.emitOpError() << "should have the same size or size of 1 " << "on dimension #" << i << " of operand #0 " << "and dimension #" << j << " of operand #1"; return mlir::failure(); } i--; j--; } while (i >= 0) { int64_t lhsSize = lhsShape[i]; expectedOutputShapeReversed.push_back(lhsSize); i--; } while (j >= 0) { int64_t rhsSize = rhsShape[j]; expectedOutputShapeReversed.push_back(rhsSize); j--; } while (!expectedOutputShapeReversed.empty()) { expectedOutputShape.push_back(expectedOutputShapeReversed.back()); expectedOutputShapeReversed.pop_back(); } } else if (lhsDims == 1 && rhsDims >= 2) { // N @ NxP -> P // N @ LxNxP -> LxP // N @ KxLxNxP -> KxLxP if (rhsShape[rhsDims - 2] != lhsShape[0]) { op.emitOpError() << "should have the same size " << "on dimension #0 of operand #0 " << "and dimension #" << rhsDims - 2 << " of operand #1"; return mlir::failure(); } for (int64_t i = 0; i < rhsDims; i++) { if (i != rhsDims - 2) { expectedOutputShape.push_back(rhsShape[i]); } } } else if (lhsDims >= 2 && rhsDims == 1) { // MxN @ N -> M // LxMxN @ N -> LxM // KxLxMxN @ N -> KxLxM if (lhsShape[lhsDims - 1] != rhsShape[0]) { op.emitOpError() << "should have the same size " << "on dimension #" << lhsDims - 1 << " of operand #0 " << "and dimension #0 of operand #1"; return mlir::failure(); } for (int64_t i = 0; i < lhsDims - 1; i++) { expectedOutputShape.push_back(lhsShape[i]); } } else { // M @ N op.emitOpError() << "should have at least one " "multi dimensional tensor " "as an operand"; return mlir::failure(); } auto resultType = ((mlir::Type)op.getResult().getType()).cast(); if (!resultType.hasStaticShape(expectedOutputShape)) { auto stream = op->emitOpError(); stream << "does not have the proper output shape of "; stream << "<" << expectedOutputShape[0]; for (size_t i = 1; i < expectedOutputShape.size(); i++) { stream << "x" << expectedOutputShape[i]; } stream << ">"; return mlir::failure(); } return mlir::success(); } mlir::LogicalResult MatMulEintIntOp::verify() { return ::mlir::concretelang::FHELinalg::verifyMatmul< mlir::concretelang::FHELinalg::MatMulEintIntOp>(*this); } mlir::LogicalResult MatMulIntEintOp::verify() { return ::mlir::concretelang::FHELinalg::verifyMatmul< mlir::concretelang::FHELinalg::MatMulIntEintOp>(*this); } mlir::SmallVector getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { mlir::SmallVector paddingInts; llvm::Optional optionalPadding = convOp.padding(); if (optionalPadding.hasValue()) { auto paddingAttr = optionalPadding.getValue(); auto paddingAttrShape = paddingAttr.getType().cast().getShape(); assert(paddingAttrShape.size() == 1 && paddingAttrShape[0] == 4 && "incorrect padding shape"); paddingInts.insert(paddingInts.begin(), paddingAttr.value_begin(), paddingAttr.value_end()); } else { paddingInts.insert(paddingInts.begin(), {0, 0, 0, 0}); } return paddingInts; } mlir::SmallVector getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { mlir::SmallVector stridesInts; llvm::Optional optionalStrides = convOp.strides(); if (optionalStrides.hasValue()) { auto stridesAttr = optionalStrides.getValue(); auto stridesAttrShape = stridesAttr.getType().cast().getShape(); assert(stridesAttrShape.size() == 1 && stridesAttrShape[0] == 2 && "incorrect strides shape"); stridesInts.insert(stridesInts.begin(), stridesAttr.value_begin(), stridesAttr.value_end()); } else { stridesInts.insert(stridesInts.begin(), {1, 1}); } return stridesInts; } mlir::SmallVector getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { mlir::SmallVector dilationsInts; llvm::Optional optionalDilations = convOp.dilations(); if (optionalDilations.hasValue()) { auto dilationsAttr = optionalDilations.getValue(); auto dilationsAttrShape = dilationsAttr.getType().cast().getShape(); assert(dilationsAttrShape.size() == 1 && dilationsAttrShape[0] == 2 && "incorrect dilations shape"); dilationsInts.insert(dilationsInts.begin(), dilationsAttr.value_begin(), dilationsAttr.value_end()); } else { dilationsInts.insert(dilationsInts.begin(), {1, 1}); } return dilationsInts; } int64_t getGroupFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { llvm::Optional optionalGroup = convOp.group(); if (optionalGroup.hasValue()) return optionalGroup.getValue(); return 1; } /// Verify the Conv2d shapes, attributes, and expected output dimensions mlir::LogicalResult Conv2dOp::verify() { auto inputTy = ((mlir::Type)this->input().getType()).cast(); auto weightTy = ((mlir::Type)this->weight().getType()).cast(); auto resultTy = ((mlir::Type)this->getResult().getType()).cast(); auto inputShape = inputTy.getShape(); auto weightShape = weightTy.getShape(); auto resultShape = resultTy.getShape(); auto p = inputTy.getElementType() .cast() .getWidth(); auto weightElementTyWidth = weightTy.getElementType().cast().getWidth(); if (weightElementTyWidth != p + 1) { this->emitOpError() << "expected weight element type to have width " << p + 1 << " but got " << weightElementTyWidth; return mlir::failure(); } // Checking dimensions if (inputShape.size() != 4) { this->emitOpError() << "input should have 4 dimensions (N*C*H*W) but got " << inputShape.size(); return mlir::failure(); } if (weightShape.size() != 4) { this->emitOpError() << "weight should have 4 dimensions (F*C*H*W) but got " << weightShape.size(); return mlir::failure(); } if (resultShape.size() != 4) { this->emitOpError() << "result should have 4 dimensions (N*C*H*W) but got " << resultShape.size(); return mlir::failure(); } // Checking attributes mlir::SmallVector paddingInts = getPaddingFromConv2d(*this); llvm::Optional optionalPadding = this->padding(); if (optionalPadding.hasValue()) { auto paddingAttr = optionalPadding.getValue(); auto paddingAttrShape = paddingAttr.getType().cast().getShape(); if (paddingAttrShape.size() != 1 || paddingAttrShape[0] != 4) { this->emitOpError() << "padding should have a single dimension of size 4, but got shape [" << paddingAttrShape << "]"; return mlir::failure(); } for (auto i = 0; i < 4; i++) { // TODO: Support padding (#427) if (paddingInts[i] != 0) { this->emitOpError() << "padding isn't yet supported, but got a non zero value (" << paddingInts[i] << ") at index " << i; return mlir::failure(); } if (paddingInts[i] < 0) { this->emitOpError() << "padding can't have a negative value, but got " << paddingInts[i] << " at index " << i; return mlir::failure(); } } } mlir::SmallVector stridesInts = getStridesFromConv2d(*this); llvm::Optional optionalStrides = this->strides(); if (optionalStrides.hasValue()) { auto stridesAttr = optionalStrides.getValue(); auto stridesAttrShape = stridesAttr.getType().cast().getShape(); if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) { this->emitOpError() << "strides should have a single dimension of size 2, but got shape [" << stridesAttrShape << "]"; return mlir::failure(); } for (auto i = 0; i < 2; i++) { if (stridesInts[i] < 1) { this->emitOpError() << "strides can't have a value less than 1, but got " << stridesInts[i] << " at index " << i; return mlir::failure(); } } } mlir::SmallVector dilationsInts = getDilationsFromConv2d(*this); llvm::Optional optionalDilations = this->dilations(); if (optionalDilations.hasValue()) { auto dilationsAttr = optionalDilations.getValue(); auto dilationsAttrShape = dilationsAttr.getType().cast().getShape(); if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) { this->emitOpError() << "dilations should have a single dimension of " "size 2, but got shape [" << dilationsAttrShape << "]"; return mlir::failure(); } for (auto i = 0; i < 2; i++) { if (dilationsInts[i] < 1) { this->emitOpError() << "dilations can't have a value less than 1, but got " << dilationsInts[i] << " at index " << i; return mlir::failure(); } } } int64_t group = getGroupFromConv2d(*this); if (group < 1) { this->emitOpError() << "group must be strictly positif, but got " << group; return mlir::failure(); } // Extracting dimensions int64_t inputN = inputShape[0], inputC = inputShape[1], inputH = inputShape[2], inputW = inputShape[3]; int64_t weightF = weightShape[0], weightC = weightShape[1], weightH = weightShape[2], weightW = weightShape[3]; int64_t resultN = resultShape[0], resultC = resultShape[1], resultH = resultShape[2], resultW = resultShape[3]; // Bias check if specified mlir::Value bias = this->bias(); if (bias) { auto biasTy = ((mlir::Type)bias.getType()).cast(); auto biasShape = biasTy.getShape(); if (biasShape.size() != 1) { this->emitOpError() << "bias should have 1 dimension but got " << biasShape.size(); return mlir::failure(); } if (biasShape[0] != weightF) { this->emitOpError() << "expected bias vector to have size " << weightF << " but got " << biasShape[0]; return mlir::failure(); } auto biasElementTyWidth = biasTy.getElementType().cast().getWidth(); if (biasElementTyWidth != p + 1) { this->emitOpError() << "expected bias element type to have width " << p + 1 << " but got " << biasElementTyWidth; return mlir::failure(); } } // Dimension sizes checks if (resultN != inputN) { this->emitOpError() << "expected result batch size to be equal to input batch size (" << inputN << ") but got " << resultN; return mlir::failure(); } if (weightC != inputC / group) { this->emitOpError() << "expected number of channels in weight to be equal to " << inputC / group << " (input_channels / group) but got " << weightC; return mlir::failure(); } if (weightF % group != 0) { this->emitOpError() << "expected number of feature maps (" << weightF << ") to be a multiple of group (" << group << ")"; return mlir::failure(); } if (weightF != resultC) { this->emitOpError() << "expected number of output channels to be equal to " "the number of filters (" << weightF << ") but got " << resultC; return mlir::failure(); } int64_t paddingH = paddingInts[0] + paddingInts[2]; int64_t paddingW = paddingInts[1] + paddingInts[3]; int64_t dilationH = dilationsInts[0]; int64_t dilationW = dilationsInts[1]; int64_t strideH = stridesInts[0]; int64_t strideW = stridesInts[1]; int64_t expectedResultH = floor((inputH + paddingH - dilationH * (weightH - 1) - 1) / strideH) + 1; int64_t expectedResultW = floor((inputW + paddingW - dilationW * (weightW - 1) - 1) / strideW) + 1; if (expectedResultH != resultH) { this->emitOpError() << "expected height of output to be equal to " << expectedResultH << " but got " << resultH; return mlir::failure(); } if (expectedResultW != resultW) { this->emitOpError() << "expected width of output to be equal to " << expectedResultW << " but got " << resultW; return mlir::failure(); } return mlir::success(); } mlir::LogicalResult FromElementOp::verify() { mlir::Value in = this->getOperand(); mlir::Value out = this->getResult(); auto inType = in.getType(); auto outType = out.getType().dyn_cast(); auto expectedOutType = outType.cloneWith({1}, inType); if (outType != expectedOutType) { this->emitOpError() << "has invalid output type (expected " << expectedOutType << ", got " << outType << ")"; return mlir::failure(); } return mlir::success(); } /// Verify the transpose shapes mlir::LogicalResult TransposeOp::verify() { mlir::Type tensorTy = ((mlir::Type)this->tensor().getType()); if (!tensorTy.isa()) { this->emitOpError() << "should have operand as tensor"; return mlir::failure(); } mlir::Type resultTy = ((mlir::Type)this->getResult().getType()); if (!resultTy.isa()) { this->emitOpError() << "should have result as tensor"; return mlir::failure(); } auto tensorShapedTy = tensorTy.dyn_cast_or_null(); auto resultShapedTy = resultTy.dyn_cast_or_null(); if (tensorShapedTy.getShape().size() != resultShapedTy.getShape().size()) { this->emitOpError() << "input and output tensors should have the same number of dimensions"; return mlir::failure(); } if (tensorShapedTy.getElementType() != resultShapedTy.getElementType()) { this->emitOpError() << "input and output tensors should have the same element type"; return mlir::failure(); } size_t n_dims = tensorShapedTy.getShape().size(); for (size_t i = 0; i < n_dims; i++) { if (tensorShapedTy.getDimSize(i) != resultShapedTy.getDimSize(n_dims - (i + 1))) { this->emitOpError() << "output tensor should have inverted dimensions of input"; return mlir::failure(); } } return mlir::success(); } /// Avoid addition with constant tensor of 0s OpFoldResult AddEintIntOp::fold(ArrayRef operands) { assert(operands.size() == 2); auto toAdd = operands[1].dyn_cast_or_null(); if (toAdd == nullptr) return nullptr; for (auto it = toAdd.begin(); it != toAdd.end(); it++) { if (*it != 0) { return nullptr; } } return getOperand(0); } /// Avoid subtraction with constant tensor of 0s OpFoldResult SubEintIntOp::fold(ArrayRef operands) { assert(operands.size() == 2); auto toSub = operands[1].dyn_cast_or_null(); if (toSub == nullptr) return nullptr; for (auto it = toSub.begin(); it != toSub.end(); it++) { if (*it != 0) { return nullptr; } } return getOperand(0); } /// Avoid multiplication with constant tensor of 1s OpFoldResult MulEintIntOp::fold(ArrayRef operands) { assert(operands.size() == 2); auto toMul = operands[1].dyn_cast_or_null(); if (toMul == nullptr) return nullptr; for (auto it = toMul.begin(); it != toMul.end(); it++) { if (*it != 1) { return nullptr; } } return getOperand(0); } } // namespace FHELinalg } // namespace concretelang } // namespace mlir #define GET_OP_CLASSES #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.cpp.inc"