mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(compiler): Add support for full unrolling of loops with SDFG-convertible ops
This adds a new option, `--unroll-loops-with-sdfg-convertible-ops`, which causes loops containing SDFG-convertible operations to be fully unrolled during the extraction of SDFG operations enabled by the `--emit-sdfg-ops` switch. This avoids constant round trips between an SDFG-capable accelerator and the host during execution of a loop. The option is limited to `scf.for` loops with static bounds and a static step size. Since full unrolling of loops with large bounds results in a large number of operations, the option is disabled by default.
This commit is contained in:
@@ -11,7 +11,8 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Creates a pass that extracts SDFG operations from a function.
///
/// \param unroll if true, `scf.for` loops containing SDFG-convertible
///        operations (with static bounds and step) are fully unrolled
///        before extraction.
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
createExtractSDFGOpsPass(bool unroll);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ struct CompilationOptions {
|
||||
bool loopParallelize;
|
||||
bool batchConcreteOps;
|
||||
bool emitSDFGOps;
|
||||
bool unrollLoopsWithSDFGConvertibleOps;
|
||||
bool dataflowParallelize;
|
||||
bool optimizeConcrete;
|
||||
/// use GPU during execution by generating GPU operations if possible
|
||||
@@ -73,8 +74,9 @@ struct CompilationOptions {
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(llvm::None), verifyDiagnostics(false),
|
||||
autoParallelize(false), loopParallelize(false), batchConcreteOps(false),
|
||||
emitSDFGOps(false), dataflowParallelize(false), optimizeConcrete(true),
|
||||
emitGPUOps(false), clientParametersFuncName(llvm::None),
|
||||
emitSDFGOps(false), unrollLoopsWithSDFGConvertibleOps(false),
|
||||
dataflowParallelize(false), optimizeConcrete(true), emitGPUOps(false),
|
||||
clientParametersFuncName(llvm::None),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG){};
|
||||
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
|
||||
@@ -58,9 +58,10 @@ mlir::LogicalResult
|
||||
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
extractSDFGOps(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
|
||||
mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool unrollLoops);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
@@ -12,6 +12,7 @@ add_mlir_dialect_library(
|
||||
ConcretelangSDFGInterfaces
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
MLIRSCFUtils
|
||||
MLIRTransforms)
|
||||
|
||||
target_link_libraries(ExtractSDFGOps PUBLIC MLIRIR)
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
|
||||
#include "concretelang/Dialect/SDFG/Interfaces/SDFGConvertibleInterface.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -18,6 +19,8 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/SCF/IR/SCF.h>
|
||||
|
||||
namespace SDFG = mlir::concretelang::SDFG;
|
||||
|
||||
@@ -34,6 +37,51 @@ SDFG::MakeStream makeStream(mlir::ImplicitLocOpBuilder &builder,
|
||||
return builder.create<SDFG::MakeStream>(streamType, dfg, name, kind);
|
||||
}
|
||||
|
||||
/// Unrolls entirely all scf loops, which directly contain an
|
||||
/// SDFG-convertible operation and whose bounds are static.
|
||||
void unrollLoopsWithSDFGConvertibleOps(mlir::func::FuncOp func) {
|
||||
mlir::DenseSet<mlir::scf::ForOp> unrollCandidates;
|
||||
|
||||
// Identify loops with SDFG-convertible ops
|
||||
func.walk([&](SDFG::SDFGConvertibleOpInterface convertible) {
|
||||
for (mlir::Operation *parent = convertible->getParentOp(); parent;
|
||||
parent = parent->getParentOp()) {
|
||||
if (mlir::scf::ForOp forOp = llvm::dyn_cast<mlir::scf::ForOp>(parent)) {
|
||||
unrollCandidates.insert(forOp);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Fully unroll all the loops if its bounds are static
|
||||
for (mlir::scf::ForOp forOp : unrollCandidates) {
|
||||
mlir::arith::ConstantIndexOp lb =
|
||||
forOp.getLowerBound().getDefiningOp<mlir::arith::ConstantIndexOp>();
|
||||
mlir::arith::ConstantIndexOp ub =
|
||||
forOp.getUpperBound().getDefiningOp<mlir::arith::ConstantIndexOp>();
|
||||
mlir::arith::ConstantIndexOp step =
|
||||
forOp.getStep().getDefiningOp<mlir::arith::ConstantIndexOp>();
|
||||
|
||||
if (!lb || !ub || !step)
|
||||
continue;
|
||||
|
||||
int64_t ilb = lb.value();
|
||||
int64_t iub = ub.value();
|
||||
int64_t istep = step.value();
|
||||
|
||||
// Unrolling requires positive bounds and step
|
||||
if (ilb < 0 || iub < 0 || istep <= 0)
|
||||
continue;
|
||||
|
||||
int64_t unrollFactor = ((iub - ilb) + (istep - 1)) / istep;
|
||||
|
||||
if (unrollFactor == 0)
|
||||
continue;
|
||||
|
||||
if (mlir::loopUnrollByFactor(forOp, (uint64_t)unrollFactor).failed())
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
StreamMappingKind determineStreamMappingKind(mlir::Value v) {
|
||||
// Determine stream type for operands:
|
||||
//
|
||||
@@ -90,11 +138,16 @@ void setInsertionPointAfterValueOrRestore(mlir::OpBuilder &builder,
|
||||
}
|
||||
|
||||
struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase<ExtractSDFGOpsPass> {
|
||||
bool unroll;
|
||||
|
||||
ExtractSDFGOpsPass() {}
|
||||
ExtractSDFGOpsPass(bool unroll) : unroll(unroll) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
mlir::func::FuncOp func = getOperation();
|
||||
|
||||
if (unroll)
|
||||
unrollLoopsWithSDFGConvertibleOps(func);
|
||||
|
||||
mlir::IRRewriter rewriter(func.getContext());
|
||||
|
||||
mlir::DenseMap<mlir::Value, SDFG::MakeStream> processOutMapping;
|
||||
@@ -205,8 +258,9 @@ struct ExtractSDFGOpsPass : public ExtractSDFGOpsBase<ExtractSDFGOpsPass> {
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
std::unique_ptr<OperationPass<mlir::func::FuncOp>> createExtractSDFGOpsPass() {
|
||||
return std::make_unique<ExtractSDFGOpsPass>();
|
||||
std::unique_ptr<OperationPass<mlir::func::FuncOp>>
|
||||
createExtractSDFGOpsPass(bool unroll) {
|
||||
return std::make_unique<ExtractSDFGOpsPass>(unroll);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -399,8 +399,9 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
// Extract SDFG data flow graph from BConcrete representation
|
||||
|
||||
if (options.emitSDFGOps) {
|
||||
if (mlir::concretelang::pipeline::extractSDFGOps(mlirContext, module,
|
||||
enablePass)
|
||||
if (mlir::concretelang::pipeline::extractSDFGOps(
|
||||
mlirContext, module, enablePass,
|
||||
options.unrollLoopsWithSDFGConvertibleOps)
|
||||
.failed()) {
|
||||
return errorDiag(
|
||||
"Extraction of SDFG operations from BConcrete representation failed");
|
||||
|
||||
@@ -274,14 +274,17 @@ lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
return pm.run(module.getOperation());
|
||||
}
|
||||
|
||||
/// Runs the ExtractSDFGOps pass on every function of the module, extracting
/// an SDFG data flow graph from the BConcrete representation.
///
/// \param enablePass predicate deciding whether a given pass is executed.
/// \param unroll if true, loops with SDFG-convertible operations are fully
///        unrolled before the SDFG operations are extracted.
/// \return failure if the pass pipeline failed.
mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
                                   mlir::ModuleOp &module,
                                   std::function<bool(mlir::Pass *)> enablePass,
                                   bool unroll) {
  mlir::PassManager pm(&context);
  pipelinePrinting("extract SDFG ops from BConcrete", pm, context);
  addPotentiallyNestedPass(
      pm, mlir::concretelang::createExtractSDFGOpsPass(unroll), enablePass);
  return pm.run(module.getOperation());
}
|
||||
|
||||
mlir::LogicalResult
|
||||
|
||||
@@ -183,6 +183,12 @@ llvm::cl::opt<bool> emitSDFGOps(
|
||||
" graphs and emit them."),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> unrollLoopsWithSDFGConvertibleOps(
|
||||
"unroll-loops-with-sdfg-convertible-ops",
|
||||
llvm::cl::desc("Causes loops containing SDFG-convertible operations to be "
|
||||
"fully unrolled."),
|
||||
llvm::cl::init(false));
|
||||
|
||||
llvm::cl::opt<bool> dataflowParallelize(
|
||||
"parallelize-dataflow",
|
||||
llvm::cl::desc(
|
||||
@@ -316,6 +322,8 @@ cmdlineCompilationOptions() {
|
||||
options.dataflowParallelize = cmdline::dataflowParallelize;
|
||||
options.batchConcreteOps = cmdline::batchConcreteOps;
|
||||
options.emitSDFGOps = cmdline::emitSDFGOps;
|
||||
options.unrollLoopsWithSDFGConvertibleOps =
|
||||
cmdline::unrollLoopsWithSDFGConvertibleOps;
|
||||
options.optimizeConcrete = cmdline::optimizeConcrete;
|
||||
options.emitGPUOps = cmdline::emitGPUOps;
|
||||
|
||||
|
||||
72
compiler/tests/check_tests/Dialect/SDFG/unrolling.mlir
Normal file
72
compiler/tests/check_tests/Dialect/SDFG/unrolling.mlir
Normal file
@@ -0,0 +1,72 @@
|
||||
// RUN: concretecompiler --action=dump-sdfg --emit-sdfg-ops --unroll-loops-with-sdfg-convertible-ops --split-input-file %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @main(%[[Varg0:.*]]: tensor<4x513xi64>, %[[Varg1:.*]]: tensor<4x513xi64>) -> tensor<4x513xi64> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "SDFG.init"() : () -> !SDFG.dfg
|
||||
// CHECK-NEXT: %[[V1:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream0", type = #SDFG.stream_kind<device_to_host>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream1", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream2", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V2]], %[[V3]], %[[V1]]) {type = #SDFG.process_kind<add_eint>} : (!SDFG.dfg, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>) -> ()
|
||||
// CHECK-NEXT: %[[V4:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream3", type = #SDFG.stream_kind<device_to_host>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: %[[V5:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream4", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: %[[V6:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream5", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V5]], %[[V6]], %[[V4]]) {type = #SDFG.process_kind<add_eint>} : (!SDFG.dfg, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>) -> ()
|
||||
// CHECK-NEXT: %[[V7:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream6", type = #SDFG.stream_kind<device_to_host>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: %[[V8:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream7", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: %[[V9:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream8", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V8]], %[[V9]], %[[V7]]) {type = #SDFG.process_kind<add_eint>} : (!SDFG.dfg, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>) -> ()
|
||||
// CHECK-NEXT: %[[V10:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream9", type = #SDFG.stream_kind<device_to_host>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: %[[V11:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream10", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: %[[V12:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "stream11", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<513xi64>>
|
||||
// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V11]], %[[V12]], %[[V10]]) {type = #SDFG.process_kind<add_eint>} : (!SDFG.dfg, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>, !SDFG.stream<tensor<513xi64>>) -> ()
|
||||
// CHECK-NEXT: "SDFG.start"(%[[V0]]) : (!SDFG.dfg) -> ()
|
||||
// CHECK-NEXT: %[[Vc0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[Vc1:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[V13:.*]] = bufferization.alloc_tensor() : tensor<4x513xi64>
|
||||
// CHECK-NEXT: %[[V14:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[Vc0]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
|
||||
// CHECK-NEXT: %[[V15:.*]] = tensor.collapse_shape %[[V14]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V2]], %[[V15]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V16:.*]] = tensor.extract_slice %[[Varg1]]{{\[}}%[[Vc0]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
|
||||
// CHECK-NEXT: %[[V17:.*]] = tensor.collapse_shape %[[V16]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V3]], %[[V17]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V18:.*]] = "SDFG.get"(%[[V1]]) : (!SDFG.stream<tensor<513xi64>>) -> tensor<513xi64>
|
||||
// CHECK-NEXT: %[[V19:.*]] = tensor.insert_slice %[[V18]] into %[[V13]]{{\[}}%[[Vc0]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<513xi64> into tensor<4x513xi64>
|
||||
// CHECK-NEXT: %[[Vc1_0:.*]] = arith.constant 1 : index
|
||||
// CHECK-NEXT: %[[V20:.*]] = arith.muli %[[Vc1]], %[[Vc1_0]] : index
|
||||
// CHECK-NEXT: %[[V21:.*]] = arith.addi %[[Vc0]], %[[V20]] : index
|
||||
// CHECK-NEXT: %[[V22:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V21]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
|
||||
// CHECK-NEXT: %[[V23:.*]] = tensor.collapse_shape %[[V22]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V5]], %[[V23]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V24:.*]] = tensor.extract_slice %[[Varg1]]{{\[}}%[[V21]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
|
||||
// CHECK-NEXT: %[[V25:.*]] = tensor.collapse_shape %[[V24]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V6]], %[[V25]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V26:.*]] = "SDFG.get"(%[[V4]]) : (!SDFG.stream<tensor<513xi64>>) -> tensor<513xi64>
|
||||
// CHECK-NEXT: %[[V27:.*]] = tensor.insert_slice %[[V26]] into %[[V19]]{{\[}}%[[V21]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<513xi64> into tensor<4x513xi64>
|
||||
// CHECK-NEXT: %[[Vc2:.*]] = arith.constant 2 : index
|
||||
// CHECK-NEXT: %[[V28:.*]] = arith.muli %[[Vc1]], %[[Vc2]] : index
|
||||
// CHECK-NEXT: %[[V29:.*]] = arith.addi %[[Vc0]], %[[V28]] : index
|
||||
// CHECK-NEXT: %[[V30:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V29]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
|
||||
// CHECK-NEXT: %[[V31:.*]] = tensor.collapse_shape %[[V30]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V8]], %[[V31]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V32:.*]] = tensor.extract_slice %[[Varg1]]{{\[}}%[[V29]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
|
||||
// CHECK-NEXT: %[[V33:.*]] = tensor.collapse_shape %[[V32]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V9]], %[[V33]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V34:.*]] = "SDFG.get"(%[[V7]]) : (!SDFG.stream<tensor<513xi64>>) -> tensor<513xi64>
|
||||
// CHECK-NEXT: %[[V35:.*]] = tensor.insert_slice %[[V34]] into %[[V27]]{{\[}}%[[V29]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<513xi64> into tensor<4x513xi64>
|
||||
// CHECK-NEXT: %[[Vc3:.*]] = arith.constant 3 : index
|
||||
// CHECK-NEXT: %[[V36:.*]] = arith.muli %[[Vc1]], %[[Vc3]] : index
|
||||
// CHECK-NEXT: %[[V37:.*]] = arith.addi %[[Vc0]], %[[V36]] : index
|
||||
// CHECK-NEXT: %[[V38:.*]] = tensor.extract_slice %[[Varg0]]{{\[}}%[[V37]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
|
||||
// CHECK-NEXT: %[[V39:.*]] = tensor.collapse_shape %[[V38]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V11]], %[[V39]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V40:.*]] = tensor.extract_slice %[[Varg1]]{{\[}}%[[V37]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<4x513xi64> to tensor<1x513xi64>
|
||||
// CHECK-NEXT: %[[V41:.*]] = tensor.collapse_shape %[[V40]] {{\[\[0, 1\]\]}} : tensor<1x513xi64> into tensor<513xi64>
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V12]], %[[V41]]) : (!SDFG.stream<tensor<513xi64>>, tensor<513xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V42:.*]] = "SDFG.get"(%[[V10]]) : (!SDFG.stream<tensor<513xi64>>) -> tensor<513xi64>
|
||||
// CHECK-NEXT: %[[V43:.*]] = tensor.insert_slice %[[V42]] into %[[V35]]{{\[}}%[[V37]], 0{{\] \[1, 513\] \[1, 1\]}} : tensor<513xi64> into tensor<4x513xi64>
|
||||
// CHECK-NEXT: "SDFG.shutdown"(%[[V0]]) : (!SDFG.dfg) -> ()
|
||||
// CHECK-NEXT: return %[[V43]] : tensor<4x513xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.add_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
Reference in New Issue
Block a user