// Part of the Concrete Compiler Project, under the BSD3 License with Zama // Exceptions. See // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt // for license information. #include #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Parser.h" #include "llvm/Support/FormatVariadic.h" #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h" namespace mlir { namespace OpTrait { namespace impl { LogicalResult verifyTensorBroadcastingRules( mlir::Operation *op, llvm::SmallVector operands, mlir::RankedTensorType result) { llvm::SmallVector> operandsShapes; size_t maxOperandsDim = 0; auto resultShape = result.getShape(); for (size_t i = 0; i < operands.size(); i++) { auto shape = operands[i].getShape(); operandsShapes.push_back(shape); maxOperandsDim = std::max(shape.size(), maxOperandsDim); } // Check the result has the same number of dimension than the highest // dimension of operands if (resultShape.size() != maxOperandsDim) { op->emitOpError() << "should have the number of dimensions of the result equal to the " "highest number of dimensions of operands" << ", got " << result.getShape().size() << " expect " << maxOperandsDim; return mlir::failure(); } // For all dimension for (size_t i = 0; i < maxOperandsDim; i++) { int64_t expectedResultDim = 1; // Check the dimension of operands shape are compatible, i.e. 
equals or 1 for (size_t j = 0; j < operandsShapes.size(); j++) { if (i < maxOperandsDim - operandsShapes[j].size()) { continue; } auto k = i - (maxOperandsDim - operandsShapes[j].size()); auto operandDim = operandsShapes[j][k]; if (expectedResultDim != 1 && operandDim != 1 && operandDim != expectedResultDim) { op->emitOpError() << "has the dimension #" << (operandsShapes[j].size() - k) << " of the operand #" << j << " incompatible with other operands" << ", got " << operandDim << " expect 1 or " << expectedResultDim; return mlir::failure(); } expectedResultDim = std::max(operandDim, expectedResultDim); } // Check the dimension of the result is compatible with dimesion of the // operands if (resultShape[i] != expectedResultDim) { op->emitOpError() << "has the dimension #" << (maxOperandsDim - i) << " of the result incompatible with operands dimension" << ", got " << resultShape[i] << " expect " << expectedResultDim; return mlir::failure(); } } return mlir::success(); } LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op) { // Check operands type are ranked tensor llvm::SmallVector tensorOperands; unsigned i = 0; for (auto opType : op->getOperandTypes()) { auto tensorType = opType.dyn_cast_or_null(); if (tensorType == nullptr) { op->emitOpError() << " should have a ranked tensor as operand #" << i; return mlir::failure(); } tensorOperands.push_back(tensorType); i++; } // Check number of result is 1 if (op->getNumResults() != 1) { op->emitOpError() << "should have exactly 1 result, got " << op->getNumResults(); } auto tensorResult = op->getResult(0).getType().dyn_cast_or_null(); if (tensorResult == nullptr) { op->emitOpError(llvm::Twine("should have a ranked tensor as result")); return mlir::failure(); } return verifyTensorBroadcastingRules(op, tensorOperands, tensorResult); } LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) { if (op->getNumOperands() != 2) { op->emitOpError() << "should have exactly 2 operands"; return mlir::failure(); } auto 
op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); if (op0Ty == nullptr || op1Ty == nullptr) { op->emitOpError() << "should have both operands as tensor"; return mlir::failure(); } auto el0Ty = op0Ty.getElementType() .dyn_cast_or_null(); if (el0Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor of operand #0"; return mlir::failure(); } auto el1Ty = op1Ty.getElementType().dyn_cast_or_null(); if (el1Ty == nullptr) { op->emitOpError() << "should have an integer as the element type of the " "tensor of operand #1"; return mlir::failure(); } if (el1Ty.getWidth() > el0Ty.getWidth() + 1) { op->emitOpError() << "should have the width of integer values less or equals " "than the width of encrypted values + 1"; return mlir::failure(); } return mlir::success(); } LogicalResult verifyTensorBinaryIntEint(mlir::Operation *op) { if (op->getNumOperands() != 2) { op->emitOpError() << "should have exactly 2 operands"; return mlir::failure(); } auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); if (op0Ty == nullptr || op1Ty == nullptr) { op->emitOpError() << "should have both operands as tensor"; return mlir::failure(); } auto el0Ty = op0Ty.getElementType().dyn_cast_or_null(); if (el0Ty == nullptr) { op->emitOpError() << "should have an integer as the element type of the " "tensor of operand #0"; return mlir::failure(); } auto el1Ty = op1Ty.getElementType() .dyn_cast_or_null(); if (el1Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor of operand #1"; return mlir::failure(); } if (el1Ty.getWidth() > el0Ty.getWidth() + 1) { op->emitOpError() << "should have the width of integer values less or equals " "than the width of encrypted values + 1"; return mlir::failure(); } return mlir::success(); } LogicalResult verifyTensorBinaryEint(mlir::Operation *op) { if 
(op->getNumOperands() != 2) { op->emitOpError() << "should have exactly 2 operands"; return mlir::failure(); } auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); if (op0Ty == nullptr || op1Ty == nullptr) { op->emitOpError() << "should have both operands as tensor"; return mlir::failure(); } auto el0Ty = op0Ty.getElementType() .dyn_cast_or_null(); if (el0Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor of operand #0"; return mlir::failure(); } auto el1Ty = op1Ty.getElementType() .dyn_cast_or_null(); if (el1Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor of operand #1"; return mlir::failure(); } if (el1Ty.getWidth() != el0Ty.getWidth()) { op->emitOpError() << "should have the width of encrypted equals" ", got " << el1Ty.getWidth() << " expect " << el0Ty.getWidth(); return mlir::failure(); } return mlir::success(); } LogicalResult verifyTensorUnaryEint(mlir::Operation *op) { if (op->getNumOperands() != 1) { op->emitOpError() << "should have exactly 1 operands"; return mlir::failure(); } auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); if (op0Ty == nullptr) { op->emitOpError() << "should have operand as tensor"; return mlir::failure(); } auto el0Ty = op0Ty.getElementType() .dyn_cast_or_null(); if (el0Ty == nullptr) { op->emitOpError() << "should have a !FHE.eint as the element type of the " "tensor operand"; return mlir::failure(); } return mlir::success(); } } // namespace impl } // namespace OpTrait } // namespace mlir namespace mlir { namespace concretelang { namespace FHELinalg { mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTableEintOp &op) { auto tTy = op.t().getType().cast(); auto tEltTy = tTy.getElementType() .cast(); auto lutTy = op.lut().getType().cast(); auto lutEltTy = lutTy.getElementType().cast(); auto resultTy = op.getResult().getType().cast(); // Check the shape 
of lut argument auto tEltwidth = tEltTy.getWidth(); mlir::SmallVector expectedShape{1 << tEltwidth}; if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isInteger(64)) { op.emitOpError() << "should have as operand #2 a tensor<2^pxi64>, where p is the width " "of the encrypted integer of the operand #1," << "expect tensor <" << expectedShape[0] << "xi64>"; return mlir::failure(); } if (!resultTy.hasStaticShape(tTy.getShape())) { op.emitOpError() << " should have same shapes for operand #1 and the result"; } return mlir::success(); } mlir::LogicalResult verifyApplyMultiLookupTable(ApplyMultiLookupTableEintOp &op) { auto tTy = op.t().getType().cast(); auto tEltTy = tTy.getElementType() .cast(); auto lutTy = op.luts().getType().cast(); auto lutEltTy = lutTy.getElementType().cast(); auto resultTy = op.getResult().getType().cast(); // Check the shape of luts argument auto lut_size = lutTy.getShape()[lutTy.getShape().size() - 1]; auto expected_lut_size = 1 << tEltTy.getWidth(); if (lut_size != expected_lut_size || !lutEltTy.isInteger(64)) { op.emitOpError() << "should have as operand #2 a " "tensor, where p is the width " "of the encrypted integer of the operand #1," << "expect tensor "; return mlir::failure(); } if (!resultTy.hasStaticShape(tTy.getShape())) { op.emitOpError() << " should have same shapes for operand #1 and the result"; } return mlir::success(); } mlir::RankedTensorType getTensorType(::mlir::Value value) { return value.getType().cast(); } template T getElmentType(::mlir::Value value) { auto tTy = getTensorType(value); return tTy.getElementType().cast(); } mlir::IntegerType getClearElmentType(::mlir::Value value) { return getElmentType(value); } FHE::EncryptedIntegerType getEncryptedElmentType(::mlir::Value value) { using namespace mlir::concretelang::FHE; return getElmentType(value); } mlir::LogicalResult verifyMapHasRightShape(ApplyMappedLookupTableEintOp &op, ::mlir::Value &lut_input, ::mlir::Value &lut_map) { auto input_shape = 
getTensorType(lut_input).getShape(); auto map_shape = getTensorType(lut_map).getShape(); if (input_shape.equals(map_shape)) { return mlir::success(); } std::string error; int input_rank = input_shape.size(); int map_rank = map_shape.size(); std::string input_name = "'t' (operand #1)"; std::string map_name = "'lut_map.getName()' (operand #3)"; if (input_rank == map_rank) { error = ": " + input_name + " dimensions differs from " + map_name; } else { error = ": " + input_name + " rank (=" + std::to_string(input_rank) + ") differs from " + map_name + " rank (=" + std::to_string(map_rank) + ")"; } op.emitOpError() << error; return mlir::failure(); } mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op, ::mlir::Value &encryptedIndex, ::mlir::Value &luts) { auto index_width = getEncryptedElmentType(encryptedIndex).getWidth(); auto actual_lut_size = getTensorType(luts).getShape().back(); auto expected_lut_size = 1 << index_width; if (actual_lut_size == expected_lut_size) { return mlir::success(); } FHE::emitErrorBadLutSize(op, "luts", "ct", expected_lut_size, index_width); return mlir::failure(); } mlir::LogicalResult verifyApplyMappedLookupTable(ApplyMappedLookupTableEintOp &op) { auto t = op.t(); auto luts = op.luts(); auto map = op.map(); auto result = op.getResult(); auto t_shape = getTensorType(t).getShape(); if (!getTensorType(result).hasStaticShape(t_shape)) { op.emitOpError() << ": `t` (operand #1) and `map` (operand #2) must have the same shape"; return mlir::failure(); } if (!getTensorType(map).getElementType().isIndex()) { op.emitOpError() << ": `map` (operand #3) should contains elements of type `index`"; return mlir::failure(); } return mlir::success(verifyMapHasRightShape(op, t, map).succeeded() && verifyLutsSize(op, t, luts).succeeded()); } ::mlir::LogicalResult verifyDotEintInt(Dot &op) { if (::mlir::failed(mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()))) { return op.emitOpError("arguments have incompatible shapes"); } 
auto lhsEltType = op.lhs() .getType() .cast() .getElementType() .cast(); auto rhsEltType = op.rhs() .getType() .cast() .getElementType() .cast(); auto resultType = op.getResult().getType().cast(); if (!mlir::concretelang::FHE:: verifyEncryptedIntegerAndIntegerInputsConsistency(op, lhsEltType, rhsEltType)) { return ::mlir::failure(); } if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(op, lhsEltType, resultType)) { return ::mlir::failure(); } return ::mlir::success(); } llvm::SmallVector verifySumCalculateActualOutputShape(mlir::Type outputType) { auto actualOutputShape = llvm::SmallVector{}; if (outputType.isa()) { auto outputTensorType = outputType.dyn_cast(); for (int64_t size : outputTensorType.getShape()) { actualOutputShape.push_back(size); } } return actualOutputShape; } llvm::SmallVector verifySumCalculateExpectedOutputShape( llvm::ArrayRef inputShape, int64_t inputDimensions, std::unordered_set &axesToDestroy, bool keepDims) { auto expectedOutputShape = llvm::SmallVector{}; for (int64_t i = 0; i < inputDimensions; i++) { bool ithAxisIsDestroyed = axesToDestroy.find(i) != axesToDestroy.end(); if (!ithAxisIsDestroyed) { expectedOutputShape.push_back(inputShape[i]); } else if (keepDims) { expectedOutputShape.push_back(1); } } return expectedOutputShape; } mlir::LogicalResult verifySum(SumOp &op) { mlir::Value input = op.getOperand(); mlir::Value output = op.getResult(); auto inputType = input.getType().dyn_cast(); mlir::Type outputType = output.getType(); FHE::EncryptedIntegerType inputElementType = inputType.getElementType().dyn_cast(); FHE::EncryptedIntegerType outputElementType = !outputType.isa() ? 
outputType.dyn_cast() : outputType.dyn_cast() .getElementType() .dyn_cast(); if (!FHE::verifyEncryptedIntegerInputAndResultConsistency( op, inputElementType, outputElementType)) { return mlir::failure(); } llvm::ArrayRef inputShape = inputType.getShape(); int64_t inputDimensions = (int64_t)inputShape.size(); mlir::ArrayAttr axes = op.axes(); bool keepDims = op.keep_dims(); auto axesToDestroy = std::unordered_set{}; for (mlir::Attribute axisAttribute : axes) { int64_t axis = axisAttribute.cast().getInt(); bool axisIsValid = (0 <= axis) && (axis < inputDimensions); if (!axisIsValid) { op.emitOpError("has invalid axes attribute"); return mlir::failure(); } axesToDestroy.insert(axis); } if (axesToDestroy.empty()) { for (int64_t i = 0; i < inputDimensions; i++) { axesToDestroy.insert(i); } } auto expectedOutputShape = verifySumCalculateExpectedOutputShape( inputShape, inputDimensions, axesToDestroy, keepDims); auto actualOutputShape = verifySumCalculateActualOutputShape(outputType); if (expectedOutputShape != actualOutputShape) { auto stream = op.emitOpError(); stream << "does not have the proper output shape of <"; if (!expectedOutputShape.empty()) { stream << expectedOutputShape[0]; for (size_t i = 1; i < expectedOutputShape.size(); i++) { stream << "x" << expectedOutputShape[i]; } } stream << ">"; return mlir::failure(); } return mlir::success(); } static bool sameShapeExceptAxis(llvm::ArrayRef shape1, llvm::ArrayRef shape2, size_t axis) { if (shape1.size() != shape2.size()) { return false; } for (size_t i = 0; i < shape1.size(); i++) { if (i != axis && shape1[i] != shape2[i]) { return false; } } return true; } mlir::LogicalResult verifyConcat(ConcatOp &op) { unsigned numOperands = op.getNumOperands(); if (numOperands < 2) { op->emitOpError() << "should have at least 2 inputs"; return mlir::failure(); } int64_t axis = op.axis(); mlir::Value out = op.out(); auto outVectorType = out.getType().dyn_cast(); auto outElementType = outVectorType.getElementType().dyn_cast(); 
llvm::ArrayRef outShape = outVectorType.getShape(); size_t outDims = outShape.size(); if (axis < 0 || (size_t)axis >= outDims) { op->emitOpError() << "has invalid axis attribute"; return mlir::failure(); } int64_t expectedOutputElementsInAxis = 0; size_t index = 0; for (mlir::Value in : op.ins()) { auto inVectorType = in.getType().dyn_cast(); auto inElementType = inVectorType.getElementType().dyn_cast(); if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(op, inElementType, outElementType)) { return ::mlir::failure(); } llvm::ArrayRef inShape = inVectorType.getShape(); if (!sameShapeExceptAxis(inShape, outShape, (size_t)axis)) { auto stream = op->emitOpError(); stream << "does not have the proper shape of <"; if (axis == 0) { stream << "?"; } else { stream << outShape[0]; } for (size_t i = 1; i < outDims; i++) { stream << "x"; if (i == (size_t)axis) { stream << "?"; } else { stream << outShape[i]; } } stream << "> for input #" << index; return mlir::failure(); } expectedOutputElementsInAxis += inShape[axis]; index += 1; } if (outShape[axis] != expectedOutputElementsInAxis) { auto stream = op->emitOpError(); stream << "does not have the proper output shape of <"; if (axis == 0) { stream << expectedOutputElementsInAxis; } else { stream << outShape[0]; } for (size_t i = 1; i < outDims; i++) { stream << "x"; if (i == (size_t)axis) { stream << expectedOutputElementsInAxis; } else { stream << outShape[i]; } } stream << ">"; return mlir::failure(); } return mlir::success(); } /// Verify the matmul shapes, the type of tensor elements should be checked by /// something else template mlir::LogicalResult verifyMatmul(MatMulOp &op) { auto lhsTy = ((mlir::Type)op.lhs().getType()).cast(); auto rhsTy = ((mlir::Type)op.rhs().getType()).cast(); auto resultTy = ((mlir::Type)op.getResult().getType()).cast(); if (lhsTy.getShape().size() != 2 || rhsTy.getShape().size() != 2) { op.emitOpError() << "should have 2D tensors as operands"; return mlir::failure(); } if 
(lhsTy.getDimSize(1) != rhsTy.getDimSize(0)) { op.emitOpError() << "should have the dimension #0 of operand #1" "equals to the dimension #1 of operand #0, expect " << lhsTy.getDimSize(1) << " got " << rhsTy.getDimSize(0); return mlir::failure(); } // Check the shape of lut argument mlir::SmallVector expectedShape{lhsTy.getDimSize(0), rhsTy.getDimSize(1)}; if (!resultTy.hasStaticShape(expectedShape)) { op.emitOpError() << "should have the result shape compatible with operands " << "shape, expect " << expectedShape[0] << "x" << expectedShape[1] << " as the shape of the result"; return mlir::failure(); } return mlir::success(); } mlir::SmallVector getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { mlir::SmallVector paddingInts; llvm::Optional optionalPadding = convOp.padding(); if (optionalPadding.hasValue()) { auto paddingAttr = optionalPadding.getValue(); auto paddingAttrShape = paddingAttr.getType().cast().getShape(); assert(paddingAttrShape.size() == 1 && paddingAttrShape[0] == 4 && "incorrect padding shape"); paddingInts.insert(paddingInts.begin(), paddingAttr.value_begin(), paddingAttr.value_end()); } else { paddingInts.insert(paddingInts.begin(), {0, 0, 0, 0}); } return paddingInts; } mlir::SmallVector getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { mlir::SmallVector stridesInts; llvm::Optional optionalStrides = convOp.strides(); if (optionalStrides.hasValue()) { auto stridesAttr = optionalStrides.getValue(); auto stridesAttrShape = stridesAttr.getType().cast().getShape(); assert(stridesAttrShape.size() == 1 && stridesAttrShape[0] == 2 && "incorrect strides shape"); stridesInts.insert(stridesInts.begin(), stridesAttr.value_begin(), stridesAttr.value_end()); } else { stridesInts.insert(stridesInts.begin(), {1, 1}); } return stridesInts; } mlir::SmallVector getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { mlir::SmallVector dilationsInts; llvm::Optional optionalDilations = convOp.dilations(); if 
(optionalDilations.hasValue()) { auto dilationsAttr = optionalDilations.getValue(); auto dilationsAttrShape = dilationsAttr.getType().cast().getShape(); assert(dilationsAttrShape.size() == 1 && dilationsAttrShape[0] == 2 && "incorrect dilations shape"); dilationsInts.insert(dilationsInts.begin(), dilationsAttr.value_begin(), dilationsAttr.value_end()); } else { dilationsInts.insert(dilationsInts.begin(), {1, 1}); } return dilationsInts; } /// Verify the Conv2d shapes, attributes, and expected output dimensions mlir::LogicalResult verifyConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { auto inputTy = ((mlir::Type)convOp.input().getType()).cast(); auto weightTy = ((mlir::Type)convOp.weight().getType()).cast(); auto resultTy = ((mlir::Type)convOp.getResult().getType()).cast(); auto inputShape = inputTy.getShape(); auto weightShape = weightTy.getShape(); auto resultShape = resultTy.getShape(); auto p = inputTy.getElementType() .cast() .getWidth(); auto weightElementTyWidth = weightTy.getElementType().cast().getWidth(); if (weightElementTyWidth != p + 1) { convOp.emitOpError() << "expected weight element type to have width " << p + 1 << " but got " << weightElementTyWidth; return mlir::failure(); } // Checking dimensions if (inputShape.size() != 4) { convOp.emitOpError() << "input should have 4 dimensions (N*C*H*W) but got " << inputShape.size(); return mlir::failure(); } if (weightShape.size() != 4) { convOp.emitOpError() << "weight should have 4 dimensions (F*C*H*W) but got " << weightShape.size(); return mlir::failure(); } if (resultShape.size() != 4) { convOp.emitOpError() << "result should have 4 dimensions (N*C*H*W) but got " << resultShape.size(); return mlir::failure(); } // Checking attributes mlir::SmallVector paddingInts = getPaddingFromConv2d(convOp); llvm::Optional optionalPadding = convOp.padding(); if (optionalPadding.hasValue()) { auto paddingAttr = optionalPadding.getValue(); auto paddingAttrShape = paddingAttr.getType().cast().getShape(); if 
(paddingAttrShape.size() != 1 || paddingAttrShape[0] != 4) { convOp.emitOpError() << "padding should have a single dimension of size 4, but got shape [" << paddingAttrShape << "]"; return mlir::failure(); } for (auto i = 0; i < 4; i++) { // TODO: Support padding (#427) if (paddingInts[i] != 0) { convOp.emitOpError() << "padding isn't yet supported, but got a non zero value (" << paddingInts[i] << ") at index " << i; return mlir::failure(); } if (paddingInts[i] < 0) { convOp.emitOpError() << "padding can't have a negative value, but got " << paddingInts[i] << " at index " << i; return mlir::failure(); } } } mlir::SmallVector stridesInts = getStridesFromConv2d(convOp); llvm::Optional optionalStrides = convOp.strides(); if (optionalStrides.hasValue()) { auto stridesAttr = optionalStrides.getValue(); auto stridesAttrShape = stridesAttr.getType().cast().getShape(); if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) { convOp.emitOpError() << "strides should have a single dimension of size 2, but got shape [" << stridesAttrShape << "]"; return mlir::failure(); } for (auto i = 0; i < 2; i++) { if (stridesInts[i] < 1) { convOp.emitOpError() << "strides can't have a value less than 1, but got " << stridesInts[i] << " at index " << i; return mlir::failure(); } } } mlir::SmallVector dilationsInts = getDilationsFromConv2d(convOp); llvm::Optional optionalDilations = convOp.dilations(); if (optionalDilations.hasValue()) { auto dilationsAttr = optionalDilations.getValue(); auto dilationsAttrShape = dilationsAttr.getType().cast().getShape(); if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) { convOp.emitOpError() << "dilations should have a single dimension of " "size 2, but got shape [" << dilationsAttrShape << "]"; return mlir::failure(); } for (auto i = 0; i < 2; i++) { if (dilationsInts[i] < 1) { convOp.emitOpError() << "dilations can't have a value less than 1, but got " << dilationsInts[i] << " at index " << i; return mlir::failure(); } } } // 
Extracting dimensions int64_t inputN = inputShape[0], inputC = inputShape[1], inputH = inputShape[2], inputW = inputShape[3]; int64_t weightF = weightShape[0], weightC = weightShape[1], weightH = weightShape[2], weightW = weightShape[3]; int64_t resultN = resultShape[0], resultC = resultShape[1], resultH = resultShape[2], resultW = resultShape[3]; // Bias check if specified mlir::Value bias = convOp.bias(); if (bias) { auto biasTy = ((mlir::Type)bias.getType()).cast(); auto biasShape = biasTy.getShape(); if (biasShape.size() != 1) { convOp.emitOpError() << "bias should have 1 dimension but got " << biasShape.size(); return mlir::failure(); } if (biasShape[0] != weightF) { convOp.emitOpError() << "expected bias vector to have size " << weightF << " but got " << biasShape[0]; return mlir::failure(); } auto biasElementTyWidth = biasTy.getElementType().cast().getWidth(); if (biasElementTyWidth != p + 1) { convOp.emitOpError() << "expected bias element type to have width " << p + 1 << " but got " << biasElementTyWidth; return mlir::failure(); } } // Dimension sizes checks if (resultN != inputN) { convOp.emitOpError() << "expected result batch size to be equal to input batch size (" << inputN << ") but got " << resultN; return mlir::failure(); } if (inputC != weightC) { convOp.emitOpError() << "expected number of channels in weight to be equal " "to number of channels in input (" << inputC << ") but got " << weightC; return mlir::failure(); } if (weightF != resultC) { convOp.emitOpError() << "expected number of output channels to be equal to " "the number of filters (" << weightF << ") but got " << resultC; return mlir::failure(); } int64_t paddingH = paddingInts[0] + paddingInts[2]; int64_t paddingW = paddingInts[1] + paddingInts[3]; int64_t dilationH = dilationsInts[0]; int64_t dilationW = dilationsInts[1]; int64_t strideH = stridesInts[0]; int64_t strideW = stridesInts[1]; int64_t expectedResultH = floor((inputH + paddingH - dilationH * (weightH - 1) - 1) / strideH) + 
1; int64_t expectedResultW = floor((inputW + paddingW - dilationW * (weightW - 1) - 1) / strideW) + 1; if (expectedResultH != resultH) { convOp.emitOpError() << "expected height of output to be equal to " << expectedResultH << " but got " << resultH; return mlir::failure(); } if (expectedResultW != resultW) { convOp.emitOpError() << "expected width of output to be equal to " << expectedResultW << " but got " << resultW; return mlir::failure(); } return mlir::success(); } //===----------------------------------------------------------------------===// // Implementation of FhelinalgConv2DNchwFchwOp // This is a generated functions from `make generate_conv_op`, and some helpers // from LinalgOps.cpp //===----------------------------------------------------------------------===// using namespace mlir; using namespace mlir::linalg; ArrayAttr FhelinalgConv2DNchwFchwOp::iterator_types() { return Builder(getContext()) .getStrArrayAttr(SmallVector{ getParallelIteratorTypeName(), getParallelIteratorTypeName(), getParallelIteratorTypeName(), getParallelIteratorTypeName(), getReductionIteratorTypeName(), getReductionIteratorTypeName(), getReductionIteratorTypeName()}); } static SmallVector getSymbolBindings(FhelinalgConv2DNchwFchwOp self) { MLIRContext *context = self.getContext(); SmallVector exprs; exprs.push_back(getAffineSymbolExpr(0, context)); exprs.push_back(getAffineSymbolExpr(1, context)); exprs.push_back(getAffineSymbolExpr(2, context)); int64_t cst3 = self.strides().getValue({0}); exprs.push_back(getAffineConstantExpr(cst3, context)); exprs.push_back(getAffineSymbolExpr(4, context)); int64_t cst5 = self.dilations().getValue({0}); exprs.push_back(getAffineConstantExpr(cst5, context)); exprs.push_back(getAffineSymbolExpr(6, context)); int64_t cst7 = self.strides().getValue({1}); exprs.push_back(getAffineConstantExpr(cst7, context)); exprs.push_back(getAffineSymbolExpr(8, context)); int64_t cst9 = self.dilations().getValue({1}); 
exprs.push_back(getAffineConstantExpr(cst9, context)); exprs.push_back(getAffineSymbolExpr(10, context)); return exprs; } ArrayAttr FhelinalgConv2DNchwFchwOp::indexing_maps() { static const char memoizeAttr[] = "linalg.memoized_indexing_maps"; ArrayAttr cached = getOperation()->getAttrOfType(memoizeAttr); if (cached) return cached; MLIRContext *context = getContext(); auto symbolBindings = getSymbolBindings(*this); SmallVector maps; maps.push_back( mlir::parseAttribute( "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, " "s7, s8, s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>", context) .cast() .getValue()); maps.back() = simplifyAffineMap( maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); maps.push_back(mlir::parseAttribute( "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, " "s4, s5, s6, s7, s8, s9, s10] -> (d1, d4, d5, d6)>", context) .cast() .getValue()); maps.back() = simplifyAffineMap( maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); maps.push_back(mlir::parseAttribute( "affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, " "s4, s5, s6, s7, s8, s9, s10] -> (d0, d1, d2, d3)>", context) .cast() .getValue()); maps.back() = simplifyAffineMap( maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0)); cached = Builder(context).getAffineMapArrayAttr(maps); getOperation()->setAttr(memoizeAttr, cached); return cached; } unsigned FhelinalgConv2DNchwFchwOp::getNumRegionArgs() { return 3; } std::string FhelinalgConv2DNchwFchwOp::getLibraryCallName() { return generateLibraryCallName(getOperation()); } bool FhelinalgConv2DNchwFchwOp::hasDynamicIndexingMaps() { return true; } LogicalResult FhelinalgConv2DNchwFchwOp::verifyIndexingMapRequiredAttributes() { Operation *op = getOperation(); if (auto attr = op->getAttrOfType("strides")) { if (!attr.getType().getElementType().isInteger(64)) return op->emitError("incorrect element type for indexing map required " "attribute 'strides'"); if 
(attr.getType().getShape() != ArrayRef{2}) return op->emitError( "incorrect shape for indexing map required attribute 'strides'"); } else { return op->emitError("missing indexing map required attribute 'strides'"); } if (auto attr = op->getAttrOfType("dilations")) { if (!attr.getType().getElementType().isInteger(64)) return op->emitError("incorrect element type for indexing map required " "attribute 'dilations'"); if (attr.getType().getShape() != ArrayRef{2}) return op->emitError( "incorrect shape for indexing map required attribute 'dilations'"); } else { return op->emitError("missing indexing map required attribute 'dilations'"); } return success(); } /// Some helpers were copied from LinalgOps.cpp /// Generic entry point to create the block for the region of a LinalgOp. /// This is used by both named structured ops created by ods-gen and by manually /// defined C++ ops. /// This is used by both builders and parsers. /// This function creates the block in the region with arguments corresponding /// to the elemental types of `inputTypes` and `outputTypes`. The latter are /// asserted to be of ShapedType. template static void fillStructuredOpRegion( OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, std::function errorHandler = nullptr); /// Generic entry point to create both the region and the block of a LinalgOp. template static void createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes, TypeRange outputTypes); /// Common parsing and printing used for both named structured ops created by /// ods-gen and by manually defined C++ ops. Does not handle regions. static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl &inputTypes, SmallVectorImpl &outputTypes); template static void printCommonStructuredOpParts(OpAsmPrinter &p, NamedStructuredOpType op); /// Specific parsing and printing for named structured ops created by ods-gen. 
// NOTE(review): as in the previous chunk, template arguments were stripped by
// the extraction tool (empty `dyn_cast()`, bare `builder.create(...)`,
// `SmallVector`, `ArrayRef`). Tokens are preserved exactly as found; `®ion`
// is mojibake for `&region` and `█` for `&block;`.

template static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
                             TypeRange inputTypes, TypeRange outputTypes);

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl &resultTypes);

template static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                                   OperationState &result);

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes);

template static void printNamedStructuredOp(OpAsmPrinter &p,
                                            NamedStructuredOpType op);

/// Helper used by the region builder below: emits the scalar body operations
/// (casts, arithmetic) at the end of the wrapped block.
class RegionBuilderHelper {
public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();
    // Identical types: nothing to do.
    if (operand.getType() == toType)
      return operand;
    if (auto toIntType = toType.dyn_cast()) {
      // If operand is floating point, cast directly to the int type.
      if (operand.getType().isa()) {
        if (isUnsignedCast)
          return builder.create(loc, toType, operand);
        return builder.create(loc, toType, operand);
      }
      // Cast index operands directly to the int type.
      if (operand.getType().isIndex())
        return builder.create(loc, toType, operand);
      if (auto fromIntType = operand.getType().dyn_cast()) {
        // Either extend or truncate.
        if (toIntType.getWidth() > fromIntType.getWidth()) {
          if (isUnsignedCast)
            return builder.create(loc, toType, operand);
          return builder.create(loc, toType, operand);
        }
        if (toIntType.getWidth() < fromIntType.getWidth())
          return builder.create(loc, toType, operand);
      }
    } else if (auto toFloatType = toType.dyn_cast()) {
      // If operand is integer, cast directly to the float type.
      // Note that it is unclear how to cast from BF16<->FP16.
      if (operand.getType().isa()) {
        if (isUnsignedCast)
          return builder.create(loc, toFloatType, operand);
        return builder.create(loc, toFloatType, operand);
      }
      if (auto fromFloatType = operand.getType().dyn_cast()) {
        if (toFloatType.getWidth() > fromFloatType.getWidth())
          return builder.create(loc, toFloatType, operand);
        if (toFloatType.getWidth() < fromFloatType.getWidth())
          return builder.create(loc, toFloatType, operand);
      }
    }
    // Unsupported combination: warn and hand back the operand unchanged.
    emitWarning(operand.getLoc()) << "could not cast operand of type "
                                  << operand.getType() << " to " << toType;
    return operand;
  }

  // Addition. Unlike the upstream Linalg helper this has two extra mixed
  // branches (clear-integer + other-type and other-type + clear-integer) —
  // presumably the FHE encrypted/clear add ops; the stripped template
  // arguments hide the exact op types (TODO confirm). Note the operand swap
  // in the first mixed branch (rhs, lhs).
  Value applyfn__add(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs) && rhs.getType().isa()) {
      return builder.create(lhs.getLoc(), rhs, lhs);
    }
    if (lhs.getType().isa() && isInteger(rhs)) {
      return builder.create(lhs.getLoc(), lhs, rhs);
    }
    if (lhs.getType().isa() && rhs.getType().isa()) {
      return builder.create(lhs.getLoc(), lhs, rhs);
    }
    llvm_unreachable("unsupported non numeric type");
  }

  // Exponential — floating point only.
  Value applyfn__exp(Value x) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(x))
      return builder.create(x.getLoc(), x);
    llvm_unreachable("unsupported non numeric type");
  }

  // Logarithm — floating point only.
  Value applyfn__log(Value x) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(x))
      return builder.create(x.getLoc(), x);
    llvm_unreachable("unsupported non numeric type");
  }

  // Subtraction for float or integer operands.
  Value applyfn__sub(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // Multiplication; the third branch (non-integer lhs, clear-integer rhs) is
  // the FHELinalg addition over upstream — presumably an encrypted-times-clear
  // multiply (TODO confirm the stripped op type).
  Value applyfn__mul(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    if (lhs.getType().isa() && isInteger(rhs)) {
      return builder.create(lhs.getLoc(), lhs, rhs);
    }
    llvm_unreachable("unsupported non numeric type");
  }

  // Signed maximum.
  Value applyfn__max(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // Unsigned maximum (float path is the same as applyfn__max).
  Value applyfn__max_unsigned(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // Signed minimum.
  Value applyfn__min(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // Unsigned minimum.
  Value applyfn__min_unsigned(Value lhs, Value rhs) {
    OpBuilder builder = getBuilder();
    if (isFloatingPoint(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    if (isInteger(lhs))
      return builder.create(lhs.getLoc(), lhs, rhs);
    llvm_unreachable("unsupported non numeric type");
  }

  // Terminates the block with a yield of `values`.
  // NOTE(review): the `if (values.empty()) return;` is dead in debug builds
  // (the assert above already fired) and only guards release builds.
  void yieldOutputs(ValueRange values) {
    assert(!values.empty() && "linalg ops must yield outputs");
    if (values.empty())
      return;
    Value first = values.front();
    OpBuilder builder = getBuilder();
    builder.create(first.getLoc(), values);
  }

  // Materializes a constant from its textual attribute form; the constant's
  // type is taken from the parsed attribute.
  Value constant(std::string value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create(loc, valueAttr.getType(), valueAttr);
  }

  // Emits an op producing the index of loop dimension `dim`.
  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) { return IntegerType::get(context, width); }
  Type getFloat32Type() { return Float32Type::get(context); }
  Type getFloat64Type() { return Float64Type::get(context); }

private:
  MLIRContext *context;
  // FIXME(review): `█` is an extraction artifact — almost certainly the
  // member declaration `&block;` (reference bound in the constructor above).
  Block █

  bool isFloatingPoint(Value value) { return value.getType().isa(); }
  bool isInteger(Value value) { return value.getType().isa(); }

  // Fresh builder positioned at the end of the wrapped block.
  OpBuilder getBuilder() {
    OpBuilder builder(context);
    builder.setInsertionPointToEnd(&block);
    return builder;
  }
};

/// Folds memref.cast producers into `op`'s operands when legal; succeeds only
/// if at least one operand was folded.
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

/// Generic entry point to create the block for the region of a LinalgOp.
/// This is used by both named structured ops created by ods-gen and by manually
/// defined C++ ops.
/// This is used by both builders and parsers.
/// This function creates the block in the region with arguments corresponding
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
/// to be ShapedType.
template static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
                                            TypeRange inputTypes,
                                            TypeRange outputTypes,
                                            std::function errorHandler) {
  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa(); }));
  // TODO: atm all operands go through getElementTypeOrSelf,
  // reconsider when we have evidence we need to.
  SmallVector argTypes;
  for (auto containers : {inputTypes, outputTypes})
    for (auto t : containers)
      argTypes.push_back(getElementTypeOrSelf(t));
  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes);
  unsigned actual = body->getNumArguments();
  unsigned expected = NamedStructuredOpType::getNumRegionArgs();
  // On arity mismatch, delegate to the caller-provided handler (parse error
  // vs. assert) instead of building a broken body.
  if (expected != actual) {
    if (errorHandler)
      errorHandler(expected, actual);
    return;
  }
  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  NamedStructuredOpType::regionBuilder(b, *body);
  // indexing_maps is an auto-generated method.
  // iterator_types is an auto-generated method.
}

/// Shared memory-effect computation: results are allocated, input buffers are
/// read, output buffers are read and written.
// FIXME(review): `SmallVectorImpl>` lost its nested template argument to the
// extraction (an EffectInstance element type).
static void getGenericEffectsImpl(
    SmallVectorImpl> &effects,
    ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
  for (Value value : results) {
    effects.emplace_back(MemoryEffects::Allocate::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (Value value : inputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (Value value : outputs) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), value,
                         SideEffects::DefaultResource::get());
  }
}

/// Generic entry point to create both the region and the block of a LinalgOp.
// NOTE(review): the handler is only invoked by fillStructuredOpRegion when
// expected != actual, so this assert is trivially true and can never fire;
// it documents the invariant rather than checks it.
template void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
                                              OperationState &result,
                                              TypeRange inputTypes,
                                              TypeRange outputTypes) {
  Region ®ion = *result.addRegion();
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes,
      [&](unsigned expected, unsigned actual) {
        assert(expected != actual && "incorrect number of arguments");
      });
}

/// Prints the optional `-> (result types)` arrow list.
static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

/// Prints the `ins(...)`/`outs(...)` operand groups shared by all named
/// structured ops.
template static void printCommonStructuredOpParts(OpAsmPrinter &p,
                                                  NamedStructuredOpType op) {
  if (!op.inputs().empty())
    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
  if (!op.outputs().empty())
    p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
}

/// Full printer for a named structured op: attributes (minus segment sizes and
/// memoized maps), operands, result types. The region is elided because it is
/// regenerated from regionBuilder at parse time.
template static void printNamedStructuredOp(OpAsmPrinter &p,
                                            NamedStructuredOpType op) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operand_segment_sizes",
                       // See generated code in mlir-linalg-yaml-gen.cpp
                       "linalg.memoized_indexing_maps"});
  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, op);
  // Results printing.
  printNamedStructuredOpResults(p, op.result_tensors().getTypes());
  // Region is elided.
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
// NOTE(review): the parseOptionalAttrDict return value is not checked here —
// a parse failure in the attribute dict would be silently ignored.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl &inputTypes,
                             SmallVectorImpl &outputTypes) {
  llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc;
  SmallVector inputsOperands, outputsOperands;
  parser.parseOptionalAttrDict(result.attributes);
  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();
    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }
  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();
  // Record the ins/outs split for the auto-generated accessors.
  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getI32VectorAttr(
                          {static_cast(inputsOperands.size()),
                           static_cast(outputsOperands.size())}));
  return success();
}

/// Rebuilds the (elided) region from the op's regionBuilder; on an arity
/// mismatch emits a parse error and dumps the partially-built block to aid
/// debugging.
template static ParseResult
parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
                             TypeRange inputTypes, TypeRange outputTypes) {
  ParseResult res = success();
  OpBuilder opBuilder(parser.getContext());
  // Resolve `captures` into `capturedValues` at parse time so we can build the
  // region with captures.
  SmallVector capturedValues;
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes,
      [&](unsigned expected, unsigned actual) {
        res = parser.emitError(
            parser.getCurrentLocation(),
            llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                          "region expects {0} args, got {1}",
                          expected, actual));
        region.front().dump();
      });
  return res;
}

/// Parses the optional `-> (result types)` arrow list.
static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

/// Full parser for a named structured op: common parts, result types, then the
/// regenerated region.
template static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                                   OperationState &result) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();
  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);
  std::unique_ptr region = std::make_unique();
  if (parseNamedStructuredOpRegion(parser, *region, inputTypes, outputTypes))
    return failure();
  result.addRegion(std::move(region));
  return success();
}

/// END OF COPY FROM LinalgOps.cpp

/// Scalar body of the convolution: out += in * filter, i.e.
/// arg2 = applyfn__add(arg2, applyfn__mul(cast(arg0), cast(arg1))).
// NOTE(review): both casts target each argument's *own* type, so with this
// code they are no-ops (cast() returns the operand unchanged when the types
// already match) — kept for symmetry with the ods-gen template, presumably.
void FhelinalgConv2DNchwFchwOp::regionBuilder(ImplicitLocOpBuilder &b,
                                              Block &block) {
  assert(3 > 0 && block.getNumArguments() == 3 &&
         "FhelinalgConv2DNchwFchwOp regionBuilder expects 3 (>=0) args");
  RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
  SmallVector yields;
  Value value1 =
      helper.cast(block.getArgument(0).getType(), block.getArgument(0), false);
  Value value2 =
      helper.cast(block.getArgument(1).getType(), block.getArgument(1), false);
  Value value3 = helper.applyfn__mul(value1, value2);
  Value value4 = helper.applyfn__add(block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}

/// Folding only absorbs producing memref.cast ops into the operands.
// FIXME(review): the ArrayRef/SmallVectorImpl parameter template arguments
// were stripped by the extraction.
LogicalResult FhelinalgConv2DNchwFchwOp::fold(ArrayRef,
                                              SmallVectorImpl &) {
  return foldMemRefCast(*this);
}

/// Memory effects: delegate to the shared helper with this op's results and
/// input/output buffer operands.
void FhelinalgConv2DNchwFchwOp::getEffects(
    SmallVectorImpl> &effects) {
  SmallVector inputBuffers = getInputBufferOperands();
  SmallVector outputBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
                        outputBuffers);
}

} // namespace FHELinalg
} // namespace concretelang
} // namespace mlir

#define GET_OP_CLASSES
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.cpp.inc"