mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
1650 lines
60 KiB
C++
1650 lines
60 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
|
// for license information.
|
|
|
|
#include <unordered_set>
|
|
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Parser.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
|
|
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
|
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
|
|
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
|
|
|
|
namespace mlir {
|
|
namespace OpTrait {
|
|
namespace impl {
|
|
|
|
LogicalResult verifyTensorBroadcastingRules(
|
|
mlir::Operation *op, llvm::SmallVector<mlir::RankedTensorType> operands,
|
|
mlir::RankedTensorType result) {
|
|
llvm::SmallVector<llvm::ArrayRef<int64_t>> operandsShapes;
|
|
size_t maxOperandsDim = 0;
|
|
auto resultShape = result.getShape();
|
|
for (size_t i = 0; i < operands.size(); i++) {
|
|
auto shape = operands[i].getShape();
|
|
operandsShapes.push_back(shape);
|
|
maxOperandsDim = std::max(shape.size(), maxOperandsDim);
|
|
}
|
|
// Check the result has the same number of dimension than the highest
|
|
// dimension of operands
|
|
if (resultShape.size() != maxOperandsDim) {
|
|
op->emitOpError()
|
|
<< "should have the number of dimensions of the result equal to the "
|
|
"highest number of dimensions of operands"
|
|
<< ", got " << result.getShape().size() << " expect " << maxOperandsDim;
|
|
return mlir::failure();
|
|
}
|
|
|
|
// For all dimension
|
|
for (size_t i = 0; i < maxOperandsDim; i++) {
|
|
int64_t expectedResultDim = 1;
|
|
|
|
// Check the dimension of operands shape are compatible, i.e. equals or 1
|
|
for (size_t j = 0; j < operandsShapes.size(); j++) {
|
|
if (i < maxOperandsDim - operandsShapes[j].size()) {
|
|
continue;
|
|
}
|
|
auto k = i - (maxOperandsDim - operandsShapes[j].size());
|
|
auto operandDim = operandsShapes[j][k];
|
|
if (expectedResultDim != 1 && operandDim != 1 &&
|
|
operandDim != expectedResultDim) {
|
|
op->emitOpError() << "has the dimension #"
|
|
<< (operandsShapes[j].size() - k)
|
|
<< " of the operand #" << j
|
|
<< " incompatible with other operands"
|
|
<< ", got " << operandDim << " expect 1 or "
|
|
<< expectedResultDim;
|
|
return mlir::failure();
|
|
}
|
|
|
|
expectedResultDim = std::max(operandDim, expectedResultDim);
|
|
}
|
|
|
|
// Check the dimension of the result is compatible with dimesion of the
|
|
// operands
|
|
if (resultShape[i] != expectedResultDim) {
|
|
op->emitOpError() << "has the dimension #" << (maxOperandsDim - i)
|
|
<< " of the result incompatible with operands dimension"
|
|
<< ", got " << resultShape[i] << " expect "
|
|
<< expectedResultDim;
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
|
|
return mlir::success();
|
|
}
|
|
|
|
LogicalResult verifyTensorBroadcastingRules(mlir::Operation *op) {
|
|
// Check operands type are ranked tensor
|
|
llvm::SmallVector<mlir::RankedTensorType> tensorOperands;
|
|
unsigned i = 0;
|
|
for (auto opType : op->getOperandTypes()) {
|
|
auto tensorType = opType.dyn_cast_or_null<mlir::RankedTensorType>();
|
|
if (tensorType == nullptr) {
|
|
op->emitOpError() << " should have a ranked tensor as operand #" << i;
|
|
return mlir::failure();
|
|
}
|
|
tensorOperands.push_back(tensorType);
|
|
i++;
|
|
}
|
|
// Check number of result is 1
|
|
if (op->getNumResults() != 1) {
|
|
op->emitOpError() << "should have exactly 1 result, got "
|
|
<< op->getNumResults();
|
|
}
|
|
auto tensorResult =
|
|
op->getResult(0).getType().dyn_cast_or_null<mlir::RankedTensorType>();
|
|
if (tensorResult == nullptr) {
|
|
op->emitOpError(llvm::Twine("should have a ranked tensor as result"));
|
|
return mlir::failure();
|
|
}
|
|
return verifyTensorBroadcastingRules(op, tensorOperands, tensorResult);
|
|
}
|
|
|
|
LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) {
|
|
if (op->getNumOperands() != 2) {
|
|
op->emitOpError() << "should have exactly 2 operands";
|
|
return mlir::failure();
|
|
}
|
|
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
|
|
auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null<mlir::TensorType>();
|
|
if (op0Ty == nullptr || op1Ty == nullptr) {
|
|
op->emitOpError() << "should have both operands as tensor";
|
|
return mlir::failure();
|
|
}
|
|
auto el0Ty =
|
|
op0Ty.getElementType()
|
|
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
|
|
if (el0Ty == nullptr) {
|
|
op->emitOpError() << "should have a !FHE.eint as the element type of the "
|
|
"tensor of operand #0";
|
|
return mlir::failure();
|
|
}
|
|
auto el1Ty = op1Ty.getElementType().dyn_cast_or_null<mlir::IntegerType>();
|
|
if (el1Ty == nullptr) {
|
|
op->emitOpError() << "should have an integer as the element type of the "
|
|
"tensor of operand #1";
|
|
return mlir::failure();
|
|
}
|
|
if (el1Ty.getWidth() > el0Ty.getWidth() + 1) {
|
|
op->emitOpError()
|
|
<< "should have the width of integer values less or equals "
|
|
"than the width of encrypted values + 1";
|
|
return mlir::failure();
|
|
}
|
|
return mlir::success();
|
|
}
|
|
|
|
LogicalResult verifyTensorBinaryIntEint(mlir::Operation *op) {
|
|
if (op->getNumOperands() != 2) {
|
|
op->emitOpError() << "should have exactly 2 operands";
|
|
return mlir::failure();
|
|
}
|
|
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
|
|
auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null<mlir::TensorType>();
|
|
if (op0Ty == nullptr || op1Ty == nullptr) {
|
|
op->emitOpError() << "should have both operands as tensor";
|
|
return mlir::failure();
|
|
}
|
|
auto el0Ty = op0Ty.getElementType().dyn_cast_or_null<mlir::IntegerType>();
|
|
if (el0Ty == nullptr) {
|
|
op->emitOpError() << "should have an integer as the element type of the "
|
|
"tensor of operand #0";
|
|
return mlir::failure();
|
|
}
|
|
auto el1Ty =
|
|
op1Ty.getElementType()
|
|
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
|
|
if (el1Ty == nullptr) {
|
|
op->emitOpError() << "should have a !FHE.eint as the element type of the "
|
|
"tensor of operand #1";
|
|
return mlir::failure();
|
|
}
|
|
if (el1Ty.getWidth() > el0Ty.getWidth() + 1) {
|
|
op->emitOpError()
|
|
<< "should have the width of integer values less or equals "
|
|
"than the width of encrypted values + 1";
|
|
return mlir::failure();
|
|
}
|
|
return mlir::success();
|
|
}
|
|
|
|
LogicalResult verifyTensorBinaryEint(mlir::Operation *op) {
|
|
if (op->getNumOperands() != 2) {
|
|
op->emitOpError() << "should have exactly 2 operands";
|
|
return mlir::failure();
|
|
}
|
|
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
|
|
auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null<mlir::TensorType>();
|
|
if (op0Ty == nullptr || op1Ty == nullptr) {
|
|
op->emitOpError() << "should have both operands as tensor";
|
|
return mlir::failure();
|
|
}
|
|
auto el0Ty =
|
|
op0Ty.getElementType()
|
|
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
|
|
if (el0Ty == nullptr) {
|
|
op->emitOpError() << "should have a !FHE.eint as the element type of the "
|
|
"tensor of operand #0";
|
|
return mlir::failure();
|
|
}
|
|
auto el1Ty =
|
|
op1Ty.getElementType()
|
|
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
|
|
if (el1Ty == nullptr) {
|
|
op->emitOpError() << "should have a !FHE.eint as the element type of the "
|
|
"tensor of operand #1";
|
|
return mlir::failure();
|
|
}
|
|
if (el1Ty.getWidth() != el0Ty.getWidth()) {
|
|
op->emitOpError() << "should have the width of encrypted equals"
|
|
", got "
|
|
<< el1Ty.getWidth() << " expect " << el0Ty.getWidth();
|
|
return mlir::failure();
|
|
}
|
|
return mlir::success();
|
|
}
|
|
|
|
LogicalResult verifyTensorUnaryEint(mlir::Operation *op) {
|
|
if (op->getNumOperands() != 1) {
|
|
op->emitOpError() << "should have exactly 1 operands";
|
|
return mlir::failure();
|
|
}
|
|
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
|
|
if (op0Ty == nullptr) {
|
|
op->emitOpError() << "should have operand as tensor";
|
|
return mlir::failure();
|
|
}
|
|
auto el0Ty =
|
|
op0Ty.getElementType()
|
|
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
|
|
if (el0Ty == nullptr) {
|
|
op->emitOpError() << "should have a !FHE.eint as the element type of the "
|
|
"tensor operand";
|
|
return mlir::failure();
|
|
}
|
|
return mlir::success();
|
|
}
|
|
|
|
} // namespace impl
|
|
|
|
} // namespace OpTrait
|
|
} // namespace mlir
|
|
|
|
namespace mlir {
|
|
namespace concretelang {
|
|
namespace FHELinalg {
|
|
|
|
mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTableEintOp &op) {
|
|
auto tTy = op.t().getType().cast<mlir::RankedTensorType>();
|
|
auto tEltTy = tTy.getElementType()
|
|
.cast<mlir::concretelang::FHE::EncryptedIntegerType>();
|
|
auto lutTy = op.lut().getType().cast<mlir::RankedTensorType>();
|
|
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();
|
|
auto resultTy = op.getResult().getType().cast<mlir::RankedTensorType>();
|
|
|
|
// Check the shape of lut argument
|
|
auto tEltwidth = tEltTy.getWidth();
|
|
mlir::SmallVector<int64_t, 1> expectedShape{1 << tEltwidth};
|
|
if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isInteger(64)) {
|
|
op.emitOpError()
|
|
<< "should have as operand #2 a tensor<2^pxi64>, where p is the width "
|
|
"of the encrypted integer of the operand #1,"
|
|
<< "expect tensor <" << expectedShape[0] << "xi64>";
|
|
return mlir::failure();
|
|
}
|
|
if (!resultTy.hasStaticShape(tTy.getShape())) {
|
|
op.emitOpError()
|
|
<< " should have same shapes for operand #1 and the result";
|
|
}
|
|
return mlir::success();
|
|
}
|
|
|
|
mlir::LogicalResult
|
|
verifyApplyMultiLookupTable(ApplyMultiLookupTableEintOp &op) {
|
|
auto tTy = op.t().getType().cast<mlir::RankedTensorType>();
|
|
auto tEltTy = tTy.getElementType()
|
|
.cast<mlir::concretelang::FHE::EncryptedIntegerType>();
|
|
auto lutTy = op.luts().getType().cast<mlir::RankedTensorType>();
|
|
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();
|
|
auto resultTy = op.getResult().getType().cast<mlir::RankedTensorType>();
|
|
|
|
// Check the shape of luts argument
|
|
auto lut_size = lutTy.getShape()[lutTy.getShape().size() - 1];
|
|
auto expected_lut_size = 1 << tEltTy.getWidth();
|
|
if (lut_size != expected_lut_size || !lutEltTy.isInteger(64)) {
|
|
op.emitOpError() << "should have as operand #2 a "
|
|
"tensor<DMx...xD1X2^pxi64>, where p is the width "
|
|
"of the encrypted integer of the operand #1,"
|
|
<< "expect tensor <DMx...xD1X" << expected_lut_size
|
|
<< "xi64>";
|
|
return mlir::failure();
|
|
}
|
|
if (!resultTy.hasStaticShape(tTy.getShape())) {
|
|
op.emitOpError()
|
|
<< " should have same shapes for operand #1 and the result";
|
|
}
|
|
return mlir::success();
|
|
}
|
|
|
|
mlir::RankedTensorType getTensorType(::mlir::Value value) {
|
|
return value.getType().cast<mlir::RankedTensorType>();
|
|
}
|
|
|
|
template <class T> T getElmentType(::mlir::Value value) {
|
|
auto tTy = getTensorType(value);
|
|
return tTy.getElementType().cast<T>();
|
|
}
|
|
|
|
mlir::IntegerType getClearElmentType(::mlir::Value value) {
|
|
return getElmentType<mlir::IntegerType>(value);
|
|
}
|
|
|
|
FHE::EncryptedIntegerType getEncryptedElmentType(::mlir::Value value) {
|
|
using namespace mlir::concretelang::FHE;
|
|
return getElmentType<FHE::EncryptedIntegerType>(value);
|
|
}
|
|
|
|
mlir::LogicalResult verifyMapHasRightShape(ApplyMappedLookupTableEintOp &op,
|
|
::mlir::Value &lut_input,
|
|
::mlir::Value &lut_map) {
|
|
auto input_shape = getTensorType(lut_input).getShape();
|
|
auto map_shape = getTensorType(lut_map).getShape();
|
|
if (input_shape.equals(map_shape)) {
|
|
return mlir::success();
|
|
}
|
|
std::string error;
|
|
int input_rank = input_shape.size();
|
|
int map_rank = map_shape.size();
|
|
std::string input_name = "'t' (operand #1)";
|
|
std::string map_name = "'lut_map.getName()' (operand #3)";
|
|
if (input_rank == map_rank) {
|
|
error = ": " + input_name + " dimensions differs from " + map_name;
|
|
} else {
|
|
error = ": " + input_name + " rank (=" + std::to_string(input_rank) +
|
|
") differs from " + map_name +
|
|
" rank (=" + std::to_string(map_rank) + ")";
|
|
}
|
|
op.emitOpError() << error;
|
|
return mlir::failure();
|
|
}
|
|
|
|
mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op,
|
|
::mlir::Value &encryptedIndex,
|
|
::mlir::Value &luts) {
|
|
auto index_width = getEncryptedElmentType(encryptedIndex).getWidth();
|
|
auto actual_lut_size = getTensorType(luts).getShape().back();
|
|
auto expected_lut_size = 1 << index_width;
|
|
if (actual_lut_size == expected_lut_size) {
|
|
return mlir::success();
|
|
}
|
|
FHE::emitErrorBadLutSize(op, "luts", "ct", expected_lut_size, index_width);
|
|
return mlir::failure();
|
|
}
|
|
|
|
mlir::LogicalResult
|
|
verifyApplyMappedLookupTable(ApplyMappedLookupTableEintOp &op) {
|
|
auto t = op.t();
|
|
auto luts = op.luts();
|
|
auto map = op.map();
|
|
auto result = op.getResult();
|
|
|
|
auto t_shape = getTensorType(t).getShape();
|
|
if (!getTensorType(result).hasStaticShape(t_shape)) {
|
|
op.emitOpError()
|
|
<< ": `t` (operand #1) and `map` (operand #2) must have the same shape";
|
|
return mlir::failure();
|
|
}
|
|
|
|
if (!getTensorType(map).getElementType().isIndex()) {
|
|
op.emitOpError()
|
|
<< ": `map` (operand #3) should contains elements of type `index`";
|
|
return mlir::failure();
|
|
}
|
|
|
|
return mlir::success(verifyMapHasRightShape(op, t, map).succeeded() &&
|
|
verifyLutsSize(op, t, luts).succeeded());
|
|
}
|
|
|
|
::mlir::LogicalResult verifyDotEintInt(Dot &op) {
|
|
if (::mlir::failed(mlir::verifyCompatibleShape(op.lhs().getType(),
|
|
op.rhs().getType()))) {
|
|
return op.emitOpError("arguments have incompatible shapes");
|
|
}
|
|
auto lhsEltType = op.lhs()
|
|
.getType()
|
|
.cast<mlir::TensorType>()
|
|
.getElementType()
|
|
.cast<FHE::EncryptedIntegerType>();
|
|
auto rhsEltType = op.rhs()
|
|
.getType()
|
|
.cast<mlir::TensorType>()
|
|
.getElementType()
|
|
.cast<mlir::IntegerType>();
|
|
auto resultType = op.getResult().getType().cast<FHE::EncryptedIntegerType>();
|
|
if (!mlir::concretelang::FHE::
|
|
verifyEncryptedIntegerAndIntegerInputsConsistency(op, lhsEltType,
|
|
rhsEltType)) {
|
|
return ::mlir::failure();
|
|
}
|
|
if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(op, lhsEltType,
|
|
resultType)) {
|
|
return ::mlir::failure();
|
|
}
|
|
return ::mlir::success();
|
|
}
|
|
|
|
llvm::SmallVector<int64_t, 3>
|
|
verifySumCalculateActualOutputShape(mlir::Type outputType) {
|
|
auto actualOutputShape = llvm::SmallVector<int64_t, 3>{};
|
|
if (outputType.isa<mlir::TensorType>()) {
|
|
auto outputTensorType = outputType.dyn_cast<mlir::TensorType>();
|
|
for (int64_t size : outputTensorType.getShape()) {
|
|
actualOutputShape.push_back(size);
|
|
}
|
|
}
|
|
return actualOutputShape;
|
|
}
|
|
|
|
llvm::SmallVector<int64_t, 3> verifySumCalculateExpectedOutputShape(
|
|
llvm::ArrayRef<int64_t> inputShape, int64_t inputDimensions,
|
|
std::unordered_set<int64_t> &axesToDestroy, bool keepDims) {
|
|
|
|
auto expectedOutputShape = llvm::SmallVector<int64_t, 3>{};
|
|
for (int64_t i = 0; i < inputDimensions; i++) {
|
|
bool ithAxisIsDestroyed = axesToDestroy.find(i) != axesToDestroy.end();
|
|
if (!ithAxisIsDestroyed) {
|
|
expectedOutputShape.push_back(inputShape[i]);
|
|
} else if (keepDims) {
|
|
expectedOutputShape.push_back(1);
|
|
}
|
|
}
|
|
return expectedOutputShape;
|
|
}
|
|
|
|
mlir::LogicalResult verifySum(SumOp &op) {
|
|
mlir::Value input = op.getOperand();
|
|
mlir::Value output = op.getResult();
|
|
|
|
auto inputType = input.getType().dyn_cast<mlir::TensorType>();
|
|
mlir::Type outputType = output.getType();
|
|
|
|
FHE::EncryptedIntegerType inputElementType =
|
|
inputType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
|
|
FHE::EncryptedIntegerType outputElementType =
|
|
!outputType.isa<mlir::TensorType>()
|
|
? outputType.dyn_cast<FHE::EncryptedIntegerType>()
|
|
: outputType.dyn_cast<mlir::TensorType>()
|
|
.getElementType()
|
|
.dyn_cast<FHE::EncryptedIntegerType>();
|
|
|
|
if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
|
|
op, inputElementType, outputElementType)) {
|
|
return mlir::failure();
|
|
}
|
|
|
|
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
|
|
int64_t inputDimensions = (int64_t)inputShape.size();
|
|
|
|
mlir::ArrayAttr axes = op.axes();
|
|
bool keepDims = op.keep_dims();
|
|
|
|
auto axesToDestroy = std::unordered_set<int64_t>{};
|
|
for (mlir::Attribute axisAttribute : axes) {
|
|
int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt();
|
|
|
|
bool axisIsValid = (0 <= axis) && (axis < inputDimensions);
|
|
if (!axisIsValid) {
|
|
op.emitOpError("has invalid axes attribute");
|
|
return mlir::failure();
|
|
}
|
|
|
|
axesToDestroy.insert(axis);
|
|
}
|
|
if (axesToDestroy.empty()) {
|
|
for (int64_t i = 0; i < inputDimensions; i++) {
|
|
axesToDestroy.insert(i);
|
|
}
|
|
}
|
|
|
|
auto expectedOutputShape = verifySumCalculateExpectedOutputShape(
|
|
inputShape, inputDimensions, axesToDestroy, keepDims);
|
|
auto actualOutputShape = verifySumCalculateActualOutputShape(outputType);
|
|
|
|
if (expectedOutputShape != actualOutputShape) {
|
|
auto stream = op.emitOpError();
|
|
|
|
stream << "does not have the proper output shape of <";
|
|
if (!expectedOutputShape.empty()) {
|
|
stream << expectedOutputShape[0];
|
|
for (size_t i = 1; i < expectedOutputShape.size(); i++) {
|
|
stream << "x" << expectedOutputShape[i];
|
|
}
|
|
}
|
|
stream << ">";
|
|
|
|
return mlir::failure();
|
|
}
|
|
|
|
return mlir::success();
|
|
}
|
|
|
|
static bool sameShapeExceptAxis(llvm::ArrayRef<int64_t> shape1,
|
|
llvm::ArrayRef<int64_t> shape2, size_t axis) {
|
|
if (shape1.size() != shape2.size()) {
|
|
return false;
|
|
}
|
|
for (size_t i = 0; i < shape1.size(); i++) {
|
|
if (i != axis && shape1[i] != shape2[i]) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
mlir::LogicalResult verifyConcat(ConcatOp &op) {
|
|
unsigned numOperands = op.getNumOperands();
|
|
if (numOperands < 2) {
|
|
op->emitOpError() << "should have at least 2 inputs";
|
|
return mlir::failure();
|
|
}
|
|
|
|
int64_t axis = op.axis();
|
|
mlir::Value out = op.out();
|
|
|
|
auto outVectorType = out.getType().dyn_cast<mlir::TensorType>();
|
|
auto outElementType =
|
|
outVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
|
|
|
|
llvm::ArrayRef<int64_t> outShape = outVectorType.getShape();
|
|
size_t outDims = outShape.size();
|
|
|
|
if (axis < 0 || (size_t)axis >= outDims) {
|
|
op->emitOpError() << "has invalid axis attribute";
|
|
return mlir::failure();
|
|
}
|
|
|
|
int64_t expectedOutputElementsInAxis = 0;
|
|
|
|
size_t index = 0;
|
|
for (mlir::Value in : op.ins()) {
|
|
auto inVectorType = in.getType().dyn_cast<mlir::TensorType>();
|
|
auto inElementType =
|
|
inVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
|
|
if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(op, inElementType,
|
|
outElementType)) {
|
|
return ::mlir::failure();
|
|
}
|
|
|
|
llvm::ArrayRef<int64_t> inShape = inVectorType.getShape();
|
|
if (!sameShapeExceptAxis(inShape, outShape, (size_t)axis)) {
|
|
auto stream = op->emitOpError();
|
|
|
|
stream << "does not have the proper shape of <";
|
|
if (axis == 0) {
|
|
stream << "?";
|
|
} else {
|
|
stream << outShape[0];
|
|
}
|
|
for (size_t i = 1; i < outDims; i++) {
|
|
stream << "x";
|
|
if (i == (size_t)axis) {
|
|
stream << "?";
|
|
} else {
|
|
stream << outShape[i];
|
|
}
|
|
}
|
|
stream << "> for input #" << index;
|
|
|
|
return mlir::failure();
|
|
}
|
|
expectedOutputElementsInAxis += inShape[axis];
|
|
|
|
index += 1;
|
|
}
|
|
|
|
if (outShape[axis] != expectedOutputElementsInAxis) {
|
|
auto stream = op->emitOpError();
|
|
|
|
stream << "does not have the proper output shape of <";
|
|
if (axis == 0) {
|
|
stream << expectedOutputElementsInAxis;
|
|
} else {
|
|
stream << outShape[0];
|
|
}
|
|
for (size_t i = 1; i < outDims; i++) {
|
|
stream << "x";
|
|
if (i == (size_t)axis) {
|
|
stream << expectedOutputElementsInAxis;
|
|
} else {
|
|
stream << outShape[i];
|
|
}
|
|
}
|
|
stream << ">";
|
|
|
|
return mlir::failure();
|
|
}
|
|
|
|
return mlir::success();
|
|
}
|
|
|
|
/// Verify the matmul shapes, the type of tensor elements should be checked by
|
|
/// something else
|
|
template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
|
|
auto lhsType =
|
|
((mlir::Type)op.lhs().getType()).cast<mlir::RankedTensorType>();
|
|
auto rhsType =
|
|
((mlir::Type)op.rhs().getType()).cast<mlir::RankedTensorType>();
|
|
|
|
llvm::ArrayRef<int64_t> lhsShape = lhsType.getShape();
|
|
llvm::ArrayRef<int64_t> rhsShape = rhsType.getShape();
|
|
|
|
int64_t lhsDims = (int64_t)lhsShape.size();
|
|
int64_t rhsDims = (int64_t)rhsShape.size();
|
|
|
|
auto expectedOutputShape = mlir::SmallVector<int64_t, 2>{};
|
|
if (lhsDims == 2 && rhsDims == 2) {
|
|
|
|
// MxN @ NxP -> MxP
|
|
|
|
if (lhsShape[1] != rhsShape[0]) {
|
|
op.emitOpError() << "should have the same size "
|
|
"on dimension #1 of operand #0 "
|
|
"and dimension #0 of operand #1";
|
|
return mlir::failure();
|
|
}
|
|
|
|
expectedOutputShape.push_back(lhsShape[0]);
|
|
expectedOutputShape.push_back(rhsShape[1]);
|
|
|
|
} else if (lhsDims >= 2 && rhsDims >= 2) {
|
|
|
|
// KxLxMxN @ NxP -> KxLxMxP
|
|
// KxLxMxN @ LxNxP -> KxLxMxP
|
|
// Kx1xMxN @ LxNxP -> KxLxMxP
|
|
|
|
// MxN @ KxLxNxP -> KxLxMxP
|
|
// LxMxN @ KxLxNxP -> KxLxMxP
|
|
// 1xMxN @ KxLxNxP -> KxLxMxP
|
|
|
|
if (lhsShape[lhsDims - 1] != rhsShape[rhsDims - 2]) {
|
|
op.emitOpError() << "should have the same size "
|
|
<< "on dimension #" << lhsDims - 1 << " of operand #0 "
|
|
<< "and dimension #" << rhsDims - 2 << " of operand #1";
|
|
return mlir::failure();
|
|
}
|
|
|
|
auto expectedOutputShapeReversed = mlir::SmallVector<int64_t, 4>{};
|
|
|
|
expectedOutputShapeReversed.push_back(rhsShape[rhsDims - 1]);
|
|
expectedOutputShapeReversed.push_back(lhsShape[lhsDims - 2]);
|
|
|
|
int64_t i = lhsDims - 3;
|
|
int64_t j = rhsDims - 3;
|
|
while (i >= 0 && j >= 0) {
|
|
int64_t lhsSize = lhsShape[i];
|
|
int64_t rhsSize = rhsShape[j];
|
|
|
|
if (lhsSize == rhsSize || lhsSize == 1 || rhsSize == 1) {
|
|
expectedOutputShapeReversed.push_back(std::max(lhsSize, rhsSize));
|
|
} else {
|
|
op.emitOpError() << "should have the same size or size of 1 "
|
|
<< "on dimension #" << i << " of operand #0 "
|
|
<< "and dimension #" << j << " of operand #1";
|
|
return mlir::failure();
|
|
}
|
|
|
|
i--;
|
|
j--;
|
|
}
|
|
while (i >= 0) {
|
|
int64_t lhsSize = lhsShape[i];
|
|
expectedOutputShapeReversed.push_back(lhsSize);
|
|
i--;
|
|
}
|
|
while (j >= 0) {
|
|
int64_t rhsSize = rhsShape[j];
|
|
expectedOutputShapeReversed.push_back(rhsSize);
|
|
j--;
|
|
}
|
|
|
|
while (!expectedOutputShapeReversed.empty()) {
|
|
expectedOutputShape.push_back(expectedOutputShapeReversed.back());
|
|
expectedOutputShapeReversed.pop_back();
|
|
}
|
|
|
|
} else if (lhsDims == 1 && rhsDims >= 2) {
|
|
|
|
// N @ NxP -> P
|
|
// N @ LxNxP -> LxP
|
|
// N @ KxLxNxP -> KxLxP
|
|
|
|
if (rhsShape[rhsDims - 2] != lhsShape[0]) {
|
|
op.emitOpError() << "should have the same size "
|
|
<< "on dimension #0 of operand #0 "
|
|
<< "and dimension #" << rhsDims - 2 << " of operand #1";
|
|
return mlir::failure();
|
|
}
|
|
|
|
for (int64_t i = 0; i < rhsDims; i++) {
|
|
if (i != rhsDims - 2) {
|
|
expectedOutputShape.push_back(rhsShape[i]);
|
|
}
|
|
}
|
|
|
|
} else if (lhsDims >= 2 && rhsDims == 1) {
|
|
|
|
// MxN @ N -> M
|
|
// LxMxN @ N -> LxM
|
|
// KxLxMxN @ N -> KxLxM
|
|
|
|
if (lhsShape[lhsDims - 1] != rhsShape[0]) {
|
|
op.emitOpError() << "should have the same size "
|
|
<< "on dimension #" << lhsDims - 1 << " of operand #0 "
|
|
<< "and dimension #0 of operand #1";
|
|
return mlir::failure();
|
|
}
|
|
|
|
for (int64_t i = 0; i < lhsDims - 1; i++) {
|
|
expectedOutputShape.push_back(lhsShape[i]);
|
|
}
|
|
|
|
} else {
|
|
|
|
// M @ N
|
|
|
|
op.emitOpError() << "should have at least one "
|
|
"multi dimensional tensor "
|
|
"as an operand";
|
|
return mlir::failure();
|
|
}
|
|
|
|
auto resultType =
|
|
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
|
|
|
|
if (!resultType.hasStaticShape(expectedOutputShape)) {
|
|
auto stream = op->emitOpError();
|
|
|
|
stream << "does not have the proper output shape of ";
|
|
|
|
stream << "<" << expectedOutputShape[0];
|
|
for (size_t i = 1; i < expectedOutputShape.size(); i++) {
|
|
stream << "x" << expectedOutputShape[i];
|
|
}
|
|
stream << ">";
|
|
|
|
return mlir::failure();
|
|
}
|
|
|
|
return mlir::success();
|
|
}
|
|
|
|
mlir::SmallVector<int64_t, 4>
|
|
getPaddingFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
|
mlir::SmallVector<int64_t, 4> paddingInts;
|
|
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
|
|
if (optionalPadding.hasValue()) {
|
|
auto paddingAttr = optionalPadding.getValue();
|
|
auto paddingAttrShape =
|
|
paddingAttr.getType().cast<RankedTensorType>().getShape();
|
|
assert(paddingAttrShape.size() == 1 && paddingAttrShape[0] == 4 &&
|
|
"incorrect padding shape");
|
|
paddingInts.insert(paddingInts.begin(), paddingAttr.value_begin<int64_t>(),
|
|
paddingAttr.value_end<int64_t>());
|
|
} else {
|
|
paddingInts.insert(paddingInts.begin(), {0, 0, 0, 0});
|
|
}
|
|
return paddingInts;
|
|
}
|
|
|
|
mlir::SmallVector<int64_t, 2>
|
|
getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
|
mlir::SmallVector<int64_t, 2> stridesInts;
|
|
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
|
|
if (optionalStrides.hasValue()) {
|
|
auto stridesAttr = optionalStrides.getValue();
|
|
auto stridesAttrShape =
|
|
stridesAttr.getType().cast<RankedTensorType>().getShape();
|
|
assert(stridesAttrShape.size() == 1 && stridesAttrShape[0] == 2 &&
|
|
"incorrect strides shape");
|
|
stridesInts.insert(stridesInts.begin(), stridesAttr.value_begin<int64_t>(),
|
|
stridesAttr.value_end<int64_t>());
|
|
} else {
|
|
stridesInts.insert(stridesInts.begin(), {1, 1});
|
|
}
|
|
return stridesInts;
|
|
}
|
|
|
|
mlir::SmallVector<int64_t, 2>
|
|
getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
|
mlir::SmallVector<int64_t, 2> dilationsInts;
|
|
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
|
|
convOp.dilations();
|
|
if (optionalDilations.hasValue()) {
|
|
auto dilationsAttr = optionalDilations.getValue();
|
|
auto dilationsAttrShape =
|
|
dilationsAttr.getType().cast<RankedTensorType>().getShape();
|
|
assert(dilationsAttrShape.size() == 1 && dilationsAttrShape[0] == 2 &&
|
|
"incorrect dilations shape");
|
|
dilationsInts.insert(dilationsInts.begin(),
|
|
dilationsAttr.value_begin<int64_t>(),
|
|
dilationsAttr.value_end<int64_t>());
|
|
} else {
|
|
dilationsInts.insert(dilationsInts.begin(), {1, 1});
|
|
}
|
|
return dilationsInts;
|
|
}
|
|
|
|
/// Verify the Conv2d shapes, attributes, and expected output dimensions
|
|
mlir::LogicalResult
|
|
verifyConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
|
|
auto inputTy =
|
|
((mlir::Type)convOp.input().getType()).cast<mlir::RankedTensorType>();
|
|
auto weightTy =
|
|
((mlir::Type)convOp.weight().getType()).cast<mlir::RankedTensorType>();
|
|
auto resultTy =
|
|
((mlir::Type)convOp.getResult().getType()).cast<mlir::RankedTensorType>();
|
|
auto inputShape = inputTy.getShape();
|
|
auto weightShape = weightTy.getShape();
|
|
auto resultShape = resultTy.getShape();
|
|
|
|
auto p = inputTy.getElementType()
|
|
.cast<mlir::concretelang::FHE::EncryptedIntegerType>()
|
|
.getWidth();
|
|
auto weightElementTyWidth =
|
|
weightTy.getElementType().cast<mlir::IntegerType>().getWidth();
|
|
if (weightElementTyWidth != p + 1) {
|
|
convOp.emitOpError() << "expected weight element type to have width "
|
|
<< p + 1 << " but got " << weightElementTyWidth;
|
|
return mlir::failure();
|
|
}
|
|
|
|
// Checking dimensions
|
|
if (inputShape.size() != 4) {
|
|
convOp.emitOpError() << "input should have 4 dimensions (N*C*H*W) but got "
|
|
<< inputShape.size();
|
|
return mlir::failure();
|
|
}
|
|
if (weightShape.size() != 4) {
|
|
convOp.emitOpError() << "weight should have 4 dimensions (F*C*H*W) but got "
|
|
<< weightShape.size();
|
|
return mlir::failure();
|
|
}
|
|
if (resultShape.size() != 4) {
|
|
convOp.emitOpError() << "result should have 4 dimensions (N*C*H*W) but got "
|
|
<< resultShape.size();
|
|
return mlir::failure();
|
|
}
|
|
|
|
// Checking attributes
|
|
mlir::SmallVector<int64_t, 4> paddingInts = getPaddingFromConv2d(convOp);
|
|
llvm::Optional<mlir::DenseIntElementsAttr> optionalPadding = convOp.padding();
|
|
if (optionalPadding.hasValue()) {
|
|
auto paddingAttr = optionalPadding.getValue();
|
|
auto paddingAttrShape =
|
|
paddingAttr.getType().cast<RankedTensorType>().getShape();
|
|
if (paddingAttrShape.size() != 1 || paddingAttrShape[0] != 4) {
|
|
convOp.emitOpError()
|
|
<< "padding should have a single dimension of size 4, but got shape ["
|
|
<< paddingAttrShape << "]";
|
|
return mlir::failure();
|
|
}
|
|
for (auto i = 0; i < 4; i++) {
|
|
// TODO: Support padding (#427)
|
|
if (paddingInts[i] != 0) {
|
|
convOp.emitOpError()
|
|
<< "padding isn't yet supported, but got a non zero value ("
|
|
<< paddingInts[i] << ") at index " << i;
|
|
return mlir::failure();
|
|
}
|
|
|
|
if (paddingInts[i] < 0) {
|
|
convOp.emitOpError() << "padding can't have a negative value, but got "
|
|
<< paddingInts[i] << " at index " << i;
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
}
|
|
mlir::SmallVector<int64_t, 2> stridesInts = getStridesFromConv2d(convOp);
|
|
llvm::Optional<mlir::DenseIntElementsAttr> optionalStrides = convOp.strides();
|
|
if (optionalStrides.hasValue()) {
|
|
auto stridesAttr = optionalStrides.getValue();
|
|
auto stridesAttrShape =
|
|
stridesAttr.getType().cast<RankedTensorType>().getShape();
|
|
if (stridesAttrShape.size() != 1 || stridesAttrShape[0] != 2) {
|
|
convOp.emitOpError()
|
|
<< "strides should have a single dimension of size 2, but got shape ["
|
|
<< stridesAttrShape << "]";
|
|
return mlir::failure();
|
|
}
|
|
for (auto i = 0; i < 2; i++) {
|
|
if (stridesInts[i] < 1) {
|
|
convOp.emitOpError()
|
|
<< "strides can't have a value less than 1, but got "
|
|
<< stridesInts[i] << " at index " << i;
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
}
|
|
mlir::SmallVector<int64_t, 2> dilationsInts = getDilationsFromConv2d(convOp);
|
|
llvm::Optional<mlir::DenseIntElementsAttr> optionalDilations =
|
|
convOp.dilations();
|
|
if (optionalDilations.hasValue()) {
|
|
auto dilationsAttr = optionalDilations.getValue();
|
|
auto dilationsAttrShape =
|
|
dilationsAttr.getType().cast<RankedTensorType>().getShape();
|
|
if (dilationsAttrShape.size() != 1 || dilationsAttrShape[0] != 2) {
|
|
convOp.emitOpError() << "dilations should have a single dimension of "
|
|
"size 2, but got shape ["
|
|
<< dilationsAttrShape << "]";
|
|
return mlir::failure();
|
|
}
|
|
for (auto i = 0; i < 2; i++) {
|
|
if (dilationsInts[i] < 1) {
|
|
convOp.emitOpError()
|
|
<< "dilations can't have a value less than 1, but got "
|
|
<< dilationsInts[i] << " at index " << i;
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
}
|
|
|
|
// Extracting dimensions
|
|
int64_t inputN = inputShape[0], inputC = inputShape[1],
|
|
inputH = inputShape[2], inputW = inputShape[3];
|
|
int64_t weightF = weightShape[0], weightC = weightShape[1],
|
|
weightH = weightShape[2], weightW = weightShape[3];
|
|
int64_t resultN = resultShape[0], resultC = resultShape[1],
|
|
resultH = resultShape[2], resultW = resultShape[3];
|
|
|
|
// Bias check if specified
|
|
mlir::Value bias = convOp.bias();
|
|
if (bias) {
|
|
auto biasTy = ((mlir::Type)bias.getType()).cast<mlir::RankedTensorType>();
|
|
auto biasShape = biasTy.getShape();
|
|
if (biasShape.size() != 1) {
|
|
convOp.emitOpError() << "bias should have 1 dimension but got "
|
|
<< biasShape.size();
|
|
return mlir::failure();
|
|
}
|
|
if (biasShape[0] != weightF) {
|
|
convOp.emitOpError() << "expected bias vector to have size " << weightF
|
|
<< " but got " << biasShape[0];
|
|
return mlir::failure();
|
|
}
|
|
auto biasElementTyWidth =
|
|
biasTy.getElementType().cast<mlir::IntegerType>().getWidth();
|
|
if (biasElementTyWidth != p + 1) {
|
|
convOp.emitOpError() << "expected bias element type to have width "
|
|
<< p + 1 << " but got " << biasElementTyWidth;
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
|
|
// Dimension sizes checks
|
|
if (resultN != inputN) {
|
|
convOp.emitOpError()
|
|
<< "expected result batch size to be equal to input batch size ("
|
|
<< inputN << ") but got " << resultN;
|
|
return mlir::failure();
|
|
}
|
|
if (inputC != weightC) {
|
|
convOp.emitOpError() << "expected number of channels in weight to be equal "
|
|
"to number of channels in input ("
|
|
<< inputC << ") but got " << weightC;
|
|
return mlir::failure();
|
|
}
|
|
if (weightF != resultC) {
|
|
convOp.emitOpError() << "expected number of output channels to be equal to "
|
|
"the number of filters ("
|
|
<< weightF << ") but got " << resultC;
|
|
return mlir::failure();
|
|
}
|
|
|
|
int64_t paddingH = paddingInts[0] + paddingInts[2];
|
|
int64_t paddingW = paddingInts[1] + paddingInts[3];
|
|
int64_t dilationH = dilationsInts[0];
|
|
int64_t dilationW = dilationsInts[1];
|
|
int64_t strideH = stridesInts[0];
|
|
int64_t strideW = stridesInts[1];
|
|
int64_t expectedResultH =
|
|
floor((inputH + paddingH - dilationH * (weightH - 1) - 1) / strideH) + 1;
|
|
int64_t expectedResultW =
|
|
floor((inputW + paddingW - dilationW * (weightW - 1) - 1) / strideW) + 1;
|
|
|
|
if (expectedResultH != resultH) {
|
|
convOp.emitOpError() << "expected height of output to be equal to "
|
|
<< expectedResultH << " but got " << resultH;
|
|
return mlir::failure();
|
|
}
|
|
if (expectedResultW != resultW) {
|
|
convOp.emitOpError() << "expected width of output to be equal to "
|
|
<< expectedResultW << " but got " << resultW;
|
|
return mlir::failure();
|
|
}
|
|
|
|
return mlir::success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Implementation of FhelinalgConv2DNchwFchwOp
|
|
// This is a generated functions from `make generate_conv_op`, and some helpers
|
|
// from LinalgOps.cpp
|
|
//===----------------------------------------------------------------------===//
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
ArrayAttr FhelinalgConv2DNchwFchwOp::iterator_types() {
|
|
return Builder(getContext())
|
|
.getStrArrayAttr(SmallVector<StringRef>{
|
|
getParallelIteratorTypeName(), getParallelIteratorTypeName(),
|
|
getParallelIteratorTypeName(), getParallelIteratorTypeName(),
|
|
getReductionIteratorTypeName(), getReductionIteratorTypeName(),
|
|
getReductionIteratorTypeName()});
|
|
}
|
|
|
|
static SmallVector<AffineExpr>
|
|
getSymbolBindings(FhelinalgConv2DNchwFchwOp self) {
|
|
MLIRContext *context = self.getContext();
|
|
SmallVector<AffineExpr> exprs;
|
|
exprs.push_back(getAffineSymbolExpr(0, context));
|
|
exprs.push_back(getAffineSymbolExpr(1, context));
|
|
exprs.push_back(getAffineSymbolExpr(2, context));
|
|
|
|
int64_t cst3 = self.strides().getValue<int64_t>({0});
|
|
exprs.push_back(getAffineConstantExpr(cst3, context));
|
|
|
|
exprs.push_back(getAffineSymbolExpr(4, context));
|
|
|
|
int64_t cst5 = self.dilations().getValue<int64_t>({0});
|
|
exprs.push_back(getAffineConstantExpr(cst5, context));
|
|
|
|
exprs.push_back(getAffineSymbolExpr(6, context));
|
|
|
|
int64_t cst7 = self.strides().getValue<int64_t>({1});
|
|
exprs.push_back(getAffineConstantExpr(cst7, context));
|
|
|
|
exprs.push_back(getAffineSymbolExpr(8, context));
|
|
|
|
int64_t cst9 = self.dilations().getValue<int64_t>({1});
|
|
exprs.push_back(getAffineConstantExpr(cst9, context));
|
|
|
|
exprs.push_back(getAffineSymbolExpr(10, context));
|
|
return exprs;
|
|
}
|
|
|
|
ArrayAttr FhelinalgConv2DNchwFchwOp::indexing_maps() {
|
|
static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
|
|
ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
|
|
if (cached)
|
|
return cached;
|
|
|
|
MLIRContext *context = getContext();
|
|
auto symbolBindings = getSymbolBindings(*this);
|
|
SmallVector<AffineMap> maps;
|
|
maps.push_back(
|
|
mlir::parseAttribute(
|
|
"affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, "
|
|
"s7, s8, s9, s10] -> (d0, d4, d2 * s3 + d5 * s5, d3 * s7 + d6 * s9)>",
|
|
context)
|
|
.cast<AffineMapAttr>()
|
|
.getValue());
|
|
maps.back() = simplifyAffineMap(
|
|
maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0));
|
|
maps.push_back(mlir::parseAttribute(
|
|
"affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, "
|
|
"s4, s5, s6, s7, s8, s9, s10] -> (d1, d4, d5, d6)>",
|
|
context)
|
|
.cast<AffineMapAttr>()
|
|
.getValue());
|
|
maps.back() = simplifyAffineMap(
|
|
maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0));
|
|
maps.push_back(mlir::parseAttribute(
|
|
"affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, "
|
|
"s4, s5, s6, s7, s8, s9, s10] -> (d0, d1, d2, d3)>",
|
|
context)
|
|
.cast<AffineMapAttr>()
|
|
.getValue());
|
|
maps.back() = simplifyAffineMap(
|
|
maps.back().replaceDimsAndSymbols({}, symbolBindings, 7, 0));
|
|
cached = Builder(context).getAffineMapArrayAttr(maps);
|
|
getOperation()->setAttr(memoizeAttr, cached);
|
|
return cached;
|
|
}
|
|
|
|
unsigned FhelinalgConv2DNchwFchwOp::getNumRegionArgs() { return 3; }
|
|
|
|
std::string FhelinalgConv2DNchwFchwOp::getLibraryCallName() {
|
|
return generateLibraryCallName(getOperation());
|
|
}
|
|
|
|
bool FhelinalgConv2DNchwFchwOp::hasDynamicIndexingMaps() { return true; }
|
|
LogicalResult FhelinalgConv2DNchwFchwOp::verifyIndexingMapRequiredAttributes() {
|
|
Operation *op = getOperation();
|
|
|
|
if (auto attr = op->getAttrOfType<DenseElementsAttr>("strides")) {
|
|
if (!attr.getType().getElementType().isInteger(64))
|
|
return op->emitError("incorrect element type for indexing map required "
|
|
"attribute 'strides'");
|
|
if (attr.getType().getShape() != ArrayRef<int64_t>{2})
|
|
return op->emitError(
|
|
"incorrect shape for indexing map required attribute 'strides'");
|
|
} else {
|
|
return op->emitError("missing indexing map required attribute 'strides'");
|
|
}
|
|
|
|
if (auto attr = op->getAttrOfType<DenseElementsAttr>("dilations")) {
|
|
if (!attr.getType().getElementType().isInteger(64))
|
|
return op->emitError("incorrect element type for indexing map required "
|
|
"attribute 'dilations'");
|
|
if (attr.getType().getShape() != ArrayRef<int64_t>{2})
|
|
return op->emitError(
|
|
"incorrect shape for indexing map required attribute 'dilations'");
|
|
} else {
|
|
return op->emitError("missing indexing map required attribute 'dilations'");
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Some helpers were copied from LinalgOps.cpp
|
|
|
|
/// Generic entry point to create the block for the region of a LinalgOp.
|
|
/// This is used by both named structured ops created by ods-gen and by manually
|
|
/// defined C++ ops.
|
|
/// This is used by both builders and parsers.
|
|
/// This function creates the block in the region with arguments corresponding
|
|
/// to the elemental types of `inputTypes` and `outputTypes`. The latter are
|
|
/// asserted to be of ShapedType.
|
|
template <typename NamedStructuredOpType>
|
|
static void fillStructuredOpRegion(
|
|
OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes,
|
|
TypeRange outputTypes,
|
|
std::function<void(unsigned, unsigned)> errorHandler = nullptr);
|
|
|
|
/// Generic entry point to create both the region and the block of a LinalgOp.
|
|
template <typename NamedStructuredOpType>
|
|
static void
|
|
createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
|
|
TypeRange inputTypes, TypeRange outputTypes);
|
|
|
|
/// Common parsing and printing used for both named structured ops created by
|
|
/// ods-gen and by manually defined C++ ops. Does not handle regions.
|
|
static ParseResult
|
|
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
|
SmallVectorImpl<Type> &inputTypes,
|
|
SmallVectorImpl<Type> &outputTypes);
|
|
template <typename NamedStructuredOpType>
|
|
static void printCommonStructuredOpParts(OpAsmPrinter &p,
|
|
NamedStructuredOpType op);
|
|
|
|
/// Specific parsing and printing for named structured ops created by ods-gen.
|
|
template <typename NamedStructuredOpType>
|
|
static ParseResult
|
|
parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
|
|
TypeRange inputTypes, TypeRange outputTypes);
|
|
|
|
static ParseResult
|
|
parseNamedStructuredOpResults(OpAsmParser &parser,
|
|
SmallVectorImpl<Type> &resultTypes);
|
|
|
|
template <typename NamedStructuredOpType>
|
|
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
|
OperationState &result);
|
|
|
|
static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
|
TypeRange resultTypes);
|
|
|
|
template <typename NamedStructuredOpType>
|
|
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
|
|
|
|
class RegionBuilderHelper {
|
|
public:
|
|
RegionBuilderHelper(MLIRContext *context, Block &block)
|
|
: context(context), block(block) {}
|
|
|
|
// Generates operations to cast the given operand to a specified type.
|
|
// If the cast cannot be performed, a warning will be issued and the
|
|
// operand returned as-is (which will presumably yield a verification
|
|
// issue downstream).
|
|
Value cast(Type toType, Value operand, bool isUnsignedCast) {
|
|
OpBuilder builder = getBuilder();
|
|
auto loc = operand.getLoc();
|
|
|
|
if (operand.getType() == toType)
|
|
return operand;
|
|
if (auto toIntType = toType.dyn_cast<IntegerType>()) {
|
|
// If operand is floating point, cast directly to the int type.
|
|
if (operand.getType().isa<FloatType>()) {
|
|
if (isUnsignedCast)
|
|
return builder.create<arith::FPToUIOp>(loc, toType, operand);
|
|
return builder.create<arith::FPToSIOp>(loc, toType, operand);
|
|
}
|
|
// Cast index operands directly to the int type.
|
|
if (operand.getType().isIndex())
|
|
return builder.create<arith::IndexCastOp>(loc, toType, operand);
|
|
if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
|
|
// Either extend or truncate.
|
|
if (toIntType.getWidth() > fromIntType.getWidth()) {
|
|
if (isUnsignedCast)
|
|
return builder.create<arith::ExtUIOp>(loc, toType, operand);
|
|
return builder.create<arith::ExtSIOp>(loc, toType, operand);
|
|
}
|
|
if (toIntType.getWidth() < fromIntType.getWidth())
|
|
return builder.create<arith::TruncIOp>(loc, toType, operand);
|
|
}
|
|
} else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
|
|
// If operand is integer, cast directly to the float type.
|
|
// Note that it is unclear how to cast from BF16<->FP16.
|
|
if (operand.getType().isa<IntegerType>()) {
|
|
if (isUnsignedCast)
|
|
return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
|
|
return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
|
|
}
|
|
if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
|
|
if (toFloatType.getWidth() > fromFloatType.getWidth())
|
|
return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
|
|
if (toFloatType.getWidth() < fromFloatType.getWidth())
|
|
return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
|
|
}
|
|
}
|
|
|
|
emitWarning(operand.getLoc()) << "could not cast operand of type "
|
|
<< operand.getType() << " to " << toType;
|
|
return operand;
|
|
}
|
|
|
|
Value applyfn__add(Value lhs, Value rhs) {
|
|
OpBuilder builder = getBuilder();
|
|
if (isFloatingPoint(lhs))
|
|
return builder.create<arith::AddFOp>(lhs.getLoc(), lhs, rhs);
|
|
if (isInteger(lhs) &&
|
|
rhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
|
return builder.create<mlir::concretelang::FHE::AddEintIntOp>(lhs.getLoc(),
|
|
rhs, lhs);
|
|
}
|
|
if (lhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() &&
|
|
isInteger(rhs)) {
|
|
return builder.create<mlir::concretelang::FHE::AddEintIntOp>(lhs.getLoc(),
|
|
lhs, rhs);
|
|
}
|
|
if (lhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() &&
|
|
rhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>()) {
|
|
return builder.create<mlir::concretelang::FHE::AddEintOp>(lhs.getLoc(),
|
|
lhs, rhs);
|
|
}
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
|
|
Value applyfn__exp(Value x) {
|
|
OpBuilder builder = getBuilder();
|
|
if (isFloatingPoint(x))
|
|
return builder.create<math::ExpOp>(x.getLoc(), x);
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
|
|
Value applyfn__log(Value x) {
|
|
OpBuilder builder = getBuilder();
|
|
if (isFloatingPoint(x))
|
|
return builder.create<math::LogOp>(x.getLoc(), x);
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
|
|
Value applyfn__sub(Value lhs, Value rhs) {
|
|
OpBuilder builder = getBuilder();
|
|
if (isFloatingPoint(lhs))
|
|
return builder.create<arith::SubFOp>(lhs.getLoc(), lhs, rhs);
|
|
if (isInteger(lhs))
|
|
return builder.create<arith::SubIOp>(lhs.getLoc(), lhs, rhs);
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
|
|
Value applyfn__mul(Value lhs, Value rhs) {
|
|
OpBuilder builder = getBuilder();
|
|
if (isFloatingPoint(lhs))
|
|
return builder.create<arith::MulFOp>(lhs.getLoc(), lhs, rhs);
|
|
if (isInteger(lhs))
|
|
return builder.create<arith::MulIOp>(lhs.getLoc(), lhs, rhs);
|
|
if (lhs.getType().isa<mlir::concretelang::FHE::EncryptedIntegerType>() &&
|
|
isInteger(rhs)) {
|
|
return builder.create<mlir::concretelang::FHE::MulEintIntOp>(lhs.getLoc(),
|
|
lhs, rhs);
|
|
}
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
|
|
Value applyfn__max(Value lhs, Value rhs) {
|
|
OpBuilder builder = getBuilder();
|
|
if (isFloatingPoint(lhs))
|
|
return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
|
|
if (isInteger(lhs))
|
|
return builder.create<MaxSIOp>(lhs.getLoc(), lhs, rhs);
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
|
|
Value applyfn__max_unsigned(Value lhs, Value rhs) {
|
|
OpBuilder builder = getBuilder();
|
|
if (isFloatingPoint(lhs))
|
|
return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
|
|
if (isInteger(lhs))
|
|
return builder.create<MaxUIOp>(lhs.getLoc(), lhs, rhs);
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
|
|
Value applyfn__min(Value lhs, Value rhs) {
|
|
OpBuilder builder = getBuilder();
|
|
if (isFloatingPoint(lhs))
|
|
return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
|
|
if (isInteger(lhs))
|
|
return builder.create<MinSIOp>(lhs.getLoc(), lhs, rhs);
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
|
|
Value applyfn__min_unsigned(Value lhs, Value rhs) {
|
|
OpBuilder builder = getBuilder();
|
|
if (isFloatingPoint(lhs))
|
|
return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
|
|
if (isInteger(lhs))
|
|
return builder.create<MinUIOp>(lhs.getLoc(), lhs, rhs);
|
|
llvm_unreachable("unsupported non numeric type");
|
|
}
|
|
|
|
void yieldOutputs(ValueRange values) {
|
|
assert(!values.empty() && "linalg ops must yield outputs");
|
|
if (values.empty())
|
|
return;
|
|
Value first = values.front();
|
|
OpBuilder builder = getBuilder();
|
|
builder.create<YieldOp>(first.getLoc(), values);
|
|
}
|
|
|
|
Value constant(std::string value) {
|
|
OpBuilder builder = getBuilder();
|
|
Location loc = builder.getUnknownLoc();
|
|
Attribute valueAttr = parseAttribute(value, builder.getContext());
|
|
return builder.create<arith::ConstantOp>(loc, valueAttr.getType(),
|
|
valueAttr);
|
|
}
|
|
|
|
Value index(int64_t dim) {
|
|
OpBuilder builder = getBuilder();
|
|
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
|
|
}
|
|
|
|
Type getIntegerType(unsigned width) {
|
|
return IntegerType::get(context, width);
|
|
}
|
|
|
|
Type getFloat32Type() { return Float32Type::get(context); }
|
|
|
|
Type getFloat64Type() { return Float64Type::get(context); }
|
|
|
|
private:
|
|
MLIRContext *context;
|
|
Block █
|
|
|
|
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
|
|
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
|
|
|
|
OpBuilder getBuilder() {
|
|
OpBuilder builder(context);
|
|
builder.setInsertionPointToEnd(&block);
|
|
return builder;
|
|
}
|
|
};
|
|
|
|
static LogicalResult foldMemRefCast(Operation *op) {
|
|
bool folded = false;
|
|
for (OpOperand &operand : op->getOpOperands()) {
|
|
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
|
|
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
|
|
operand.set(castOp.getOperand());
|
|
folded = true;
|
|
}
|
|
}
|
|
return success(folded);
|
|
}
|
|
|
|
/// Generic entry point to create the block for the region of a LinalgOp.
|
|
/// This is used by both named structured ops created by ods-gen and by manually
|
|
/// defined C++ ops.
|
|
/// This is used by both builders and parsers.
|
|
/// This function creates the block in the region with arguments corresponding
|
|
/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
|
|
/// to be ShapedType.
|
|
template <typename NamedStructuredOpType>
|
|
static void
|
|
fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
|
|
TypeRange inputTypes, TypeRange outputTypes,
|
|
std::function<void(unsigned, unsigned)> errorHandler) {
|
|
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
|
|
|
|
// TODO: atm all operands go through getElementTypeOrSelf,
|
|
// reconsider when we have evidence we need to.
|
|
SmallVector<Type, 8> argTypes;
|
|
for (auto containers : {inputTypes, outputTypes})
|
|
for (auto t : containers)
|
|
argTypes.push_back(getElementTypeOrSelf(t));
|
|
|
|
// RAII.
|
|
OpBuilder::InsertionGuard guard(opBuilder);
|
|
Block *body = opBuilder.createBlock(®ion, /*insertPt=*/{}, argTypes);
|
|
unsigned actual = body->getNumArguments();
|
|
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
|
|
if (expected != actual) {
|
|
if (errorHandler)
|
|
errorHandler(expected, actual);
|
|
return;
|
|
}
|
|
|
|
opBuilder.setInsertionPointToStart(body);
|
|
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
|
|
NamedStructuredOpType::regionBuilder(b, *body);
|
|
|
|
// indexing_maps is an auto-generated method.
|
|
|
|
// iterator_types is an auto-generated method.
|
|
}
|
|
|
|
static void getGenericEffectsImpl(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects,
|
|
ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
|
|
for (Value value : results) {
|
|
effects.emplace_back(MemoryEffects::Allocate::get(), value,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
for (Value value : inputBuffers) {
|
|
effects.emplace_back(MemoryEffects::Read::get(), value,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
for (Value value : outputs) {
|
|
effects.emplace_back(MemoryEffects::Read::get(), value,
|
|
SideEffects::DefaultResource::get());
|
|
effects.emplace_back(MemoryEffects::Write::get(), value,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
}
|
|
|
|
/// Generic entry point to create both the region and the block of a LinalgOp.
|
|
template <typename NamedStructuredOpType>
|
|
void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
|
|
OperationState &result,
|
|
TypeRange inputTypes,
|
|
TypeRange outputTypes) {
|
|
Region ®ion = *result.addRegion();
|
|
fillStructuredOpRegion<NamedStructuredOpType>(
|
|
opBuilder, region, inputTypes, outputTypes,
|
|
[&](unsigned expected, unsigned actual) {
|
|
assert(expected != actual && "incorrect number of arguments");
|
|
});
|
|
}
|
|
|
|
static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
|
TypeRange resultTypes) {
|
|
if (resultTypes.empty())
|
|
return;
|
|
p.printOptionalArrowTypeList(resultTypes);
|
|
}
|
|
|
|
template <typename NamedStructuredOpType>
|
|
static void printCommonStructuredOpParts(OpAsmPrinter &p,
|
|
NamedStructuredOpType op) {
|
|
if (!op.inputs().empty())
|
|
p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
|
|
if (!op.outputs().empty())
|
|
p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
|
|
}
|
|
|
|
template <typename NamedStructuredOpType>
|
|
static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
|
|
p.printOptionalAttrDict(
|
|
op->getAttrs(),
|
|
/*elidedAttrs=*/{"operand_segment_sizes",
|
|
// See generated code in mlir-linalg-yaml-gen.cpp
|
|
"linalg.memoized_indexing_maps"});
|
|
|
|
// Printing is shared with generic ops, except for the region and
|
|
// attributes.
|
|
printCommonStructuredOpParts(p, op);
|
|
|
|
// Results printing.
|
|
printNamedStructuredOpResults(p, op.result_tensors().getTypes());
|
|
|
|
// Region is elided.
|
|
}
|
|
|
|
/// Common parsing used for both named structured ops created by ods-gen and by
|
|
/// manually defined C++ ops. Does not handle regions.
|
|
static ParseResult
|
|
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
|
SmallVectorImpl<Type> &inputTypes,
|
|
SmallVectorImpl<Type> &outputTypes) {
|
|
llvm::SMLoc inputsOperandsLoc, outputsOperandsLoc;
|
|
SmallVector<OpAsmParser::OperandType, 4> inputsOperands, outputsOperands;
|
|
|
|
parser.parseOptionalAttrDict(result.attributes);
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("ins"))) {
|
|
if (parser.parseLParen())
|
|
return failure();
|
|
|
|
inputsOperandsLoc = parser.getCurrentLocation();
|
|
if (parser.parseOperandList(inputsOperands) ||
|
|
parser.parseColonTypeList(inputTypes) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("outs"))) {
|
|
outputsOperandsLoc = parser.getCurrentLocation();
|
|
if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
|
|
parser.parseColonTypeList(outputTypes) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
|
|
if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
|
|
result.operands) ||
|
|
parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
|
|
result.operands))
|
|
return failure();
|
|
|
|
result.addAttribute("operand_segment_sizes",
|
|
parser.getBuilder().getI32VectorAttr(
|
|
{static_cast<int32_t>(inputsOperands.size()),
|
|
static_cast<int32_t>(outputsOperands.size())}));
|
|
return success();
|
|
}
|
|
|
|
template <typename NamedStructuredOpType>
|
|
static ParseResult
|
|
parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
|
|
TypeRange inputTypes, TypeRange outputTypes) {
|
|
ParseResult res = success();
|
|
OpBuilder opBuilder(parser.getContext());
|
|
// Resolve `captures` into `capturedValues` at parse time so we can build the
|
|
// region with captures.
|
|
SmallVector<Value> capturedValues;
|
|
fillStructuredOpRegion<NamedStructuredOpType>(
|
|
opBuilder, region, inputTypes, outputTypes,
|
|
[&](unsigned expected, unsigned actual) {
|
|
res = parser.emitError(
|
|
parser.getCurrentLocation(),
|
|
llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
|
|
"region expects {0} args, got {1}",
|
|
expected, actual));
|
|
region.front().dump();
|
|
});
|
|
return res;
|
|
}
|
|
|
|
static ParseResult
|
|
parseNamedStructuredOpResults(OpAsmParser &parser,
|
|
SmallVectorImpl<Type> &resultTypes) {
|
|
if (parser.parseOptionalArrowTypeList(resultTypes))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
template <typename NamedStructuredOpType>
|
|
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
|
OperationState &result) {
|
|
// TODO: Enable when ods-gen supports captures.
|
|
SmallVector<Type, 1> inputTypes, outputTypes;
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
|
return failure();
|
|
|
|
// TODO: consider merging results parsing into region parsing.
|
|
// Need to wait for declarative assembly resolution to decide.
|
|
SmallVector<Type, 1> outputTensorsTypes;
|
|
if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
|
|
return failure();
|
|
result.addTypes(outputTensorsTypes);
|
|
|
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
|
if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
|
|
parser, *region, inputTypes, outputTypes))
|
|
return failure();
|
|
result.addRegion(std::move(region));
|
|
|
|
return success();
|
|
}
|
|
|
|
/// END OF COPY FROM LinalgOps.cpp
|
|
|
|
void FhelinalgConv2DNchwFchwOp::regionBuilder(ImplicitLocOpBuilder &b,
|
|
Block &block) {
|
|
assert(3 > 0 && block.getNumArguments() == 3 &&
|
|
"FhelinalgConv2DNchwFchwOp regionBuilder expects 3 (>=0) args");
|
|
RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
|
|
SmallVector<Value> yields;
|
|
Value value1 =
|
|
helper.cast(block.getArgument(0).getType(), block.getArgument(0), false);
|
|
Value value2 =
|
|
helper.cast(block.getArgument(1).getType(), block.getArgument(1), false);
|
|
Value value3 = helper.applyfn__mul(value1, value2);
|
|
Value value4 = helper.applyfn__add(block.getArgument(2), value3);
|
|
yields.push_back(value4);
|
|
helper.yieldOutputs(yields);
|
|
}
|
|
|
|
LogicalResult FhelinalgConv2DNchwFchwOp::fold(ArrayRef<Attribute>,
|
|
SmallVectorImpl<OpFoldResult> &) {
|
|
return foldMemRefCast(*this);
|
|
}
|
|
void FhelinalgConv2DNchwFchwOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
SmallVector<Value> inputBuffers = getInputBufferOperands();
|
|
SmallVector<Value> outputBuffers = getOutputBufferOperands();
|
|
getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
|
|
outputBuffers);
|
|
}
|
|
|
|
/// Verify the transpose shapes
|
|
mlir::LogicalResult verifyTranspose(TransposeOp &transposeOp) {
|
|
mlir::Type tensorTy = ((mlir::Type)transposeOp.tensor().getType());
|
|
if (!tensorTy.isa<RankedTensorType>()) {
|
|
transposeOp->emitOpError() << "should have operand as tensor";
|
|
return mlir::failure();
|
|
}
|
|
mlir::Type resultTy = ((mlir::Type)transposeOp.getResult().getType());
|
|
if (!resultTy.isa<RankedTensorType>()) {
|
|
transposeOp->emitOpError() << "should have result as tensor";
|
|
return mlir::failure();
|
|
}
|
|
auto tensorShapedTy = tensorTy.dyn_cast_or_null<mlir::ShapedType>();
|
|
auto resultShapedTy = resultTy.dyn_cast_or_null<mlir::ShapedType>();
|
|
if (tensorShapedTy.getShape().size() != resultShapedTy.getShape().size()) {
|
|
transposeOp.emitOpError()
|
|
<< "input and output tensors should have the same number of dimensions";
|
|
return mlir::failure();
|
|
}
|
|
if (tensorShapedTy.getElementType() != resultShapedTy.getElementType()) {
|
|
transposeOp.emitOpError()
|
|
<< "input and output tensors should have the same element type";
|
|
return mlir::failure();
|
|
}
|
|
size_t n_dims = tensorShapedTy.getShape().size();
|
|
for (size_t i = 0; i < n_dims; i++) {
|
|
if (tensorShapedTy.getDimSize(i) !=
|
|
resultShapedTy.getDimSize(n_dims - (i + 1))) {
|
|
transposeOp.emitOpError()
|
|
<< "output tensor should have inverted dimensions of input";
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
return mlir::success();
|
|
}
|
|
|
|
} // namespace FHELinalg
|
|
} // namespace concretelang
|
|
} // namespace mlir
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.cpp.inc"
|