fix(compiler): Fix lowering of tensor.from_elements with crt

This commit is contained in:
Quentin Bourgerie
2022-11-24 10:53:10 +01:00
parent 5d89ad0f84
commit 3bade6603a
2 changed files with 15 additions and 4 deletions

View File

@@ -777,6 +777,8 @@ struct FromElementsOpPattern
if (converter.isLegal(resultTy)) {
return mlir::failure();
}
auto oldTensorResultTy = resultTy.cast<mlir::RankedTensorType>();
auto oldRank = oldTensorResultTy.getRank();
auto newTensorResultTy =
converter.convertType(resultTy).cast<mlir::RankedTensorType>();
@@ -786,11 +788,12 @@ struct FromElementsOpPattern
mlir::Value tensor = rewriter.create<mlir::bufferization::AllocTensorOp>(
fromElementsOp.getLoc(), newTensorResultTy, mlir::ValueRange{});
// sizes are [1, ..., 1, lweSize]
llvm::SmallVector<mlir::OpFoldResult> sizes(newRank - 1,
// sizes are [1, ..., 1, diffShape...]
llvm::SmallVector<mlir::OpFoldResult> sizes(oldRank,
rewriter.getI64IntegerAttr(1));
sizes.push_back(
rewriter.getI64IntegerAttr(*(newTensorResultTy.getShape().end() - 1)));
for (auto i = newRank - oldRank; i > 0; i--) {
sizes.push_back(rewriter.getI64IntegerAttr(*(newShape.end() - i)));
}
// strides are [1, ..., 1]
llvm::SmallVector<mlir::OpFoldResult> oneStrides(

View File

@@ -0,0 +1,8 @@
// RUN: concretecompiler --action=dump-llvm-ir %s
// Just ensure that compile
// https://github.com/zama-ai/concrete-compiler-internal/issues/785
func.func @main(%arg0: !FHE.eint<15>, %cst: tensor<32768xi64>) -> tensor<1x!FHE.eint<15>> {
%1 = "FHE.apply_lookup_table"(%arg0, %cst) : (!FHE.eint<15>, tensor<32768xi64>) -> !FHE.eint<15>
%6 = tensor.from_elements %1 : tensor<1x!FHE.eint<15>> // ERROR HERE line 4
return %6 : tensor<1x!FHE.eint<15>>
}