mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
fix(dfr-compiler): clone memref task arguments using identity maps for serialization.
This commit is contained in:
committed by
Quentin Bourgerie
parent
8cd3a3a599
commit
93802d128b
@@ -201,7 +201,6 @@ getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) {
|
||||
|
||||
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
func::FuncOp workFunction) {
|
||||
DataLayout dataLayout = DataLayout::closest(DFTOp);
|
||||
Region &opBody = DFTOp->getParentOfType<func::FuncOp>().getBody();
|
||||
OpBuilder builder(DFTOp);
|
||||
|
||||
@@ -213,7 +212,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
if (!val.getType().isa<RT::FutureType>()) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Type futType = RT::FutureType::get(val.getType());
|
||||
Value memrefCloned, newval = val;
|
||||
Value memrefCloned;
|
||||
|
||||
// Find out if this value is needed in any other task
|
||||
SmallVector<Operation *, 2> taskOps;
|
||||
@@ -230,20 +229,6 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
// the memref in order to allow the deallocation pass which does
|
||||
// not synchronize with task execution.
|
||||
if (val.getType().isa<mlir::MemRefType>()) {
|
||||
// Get the type of memref that we will clone. In case this is
|
||||
// a subview, we discard the mapping so we copy to a contiguous
|
||||
// layout which pre-serializes this.
|
||||
MemRefType mrType = val.getType().dyn_cast<mlir::MemRefType>();
|
||||
if (!mrType.getLayout().isIdentity()) {
|
||||
unsigned rank = mrType.getRank();
|
||||
mrType = MemRefType::Builder(mrType)
|
||||
.setShape(mrType.getShape())
|
||||
.setLayout(AffineMapAttr::get(
|
||||
builder.getMultiDimIdentityMap(rank)));
|
||||
}
|
||||
newval = builder.create<mlir::memref::AllocOp>(val.getLoc(), mrType)
|
||||
.getResult();
|
||||
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
|
||||
memrefCloned = builder.create<arith::ConstantOp>(
|
||||
val.getLoc(), builder.getI64IntegerAttr(1));
|
||||
} else {
|
||||
@@ -252,7 +237,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
}
|
||||
|
||||
auto mrf = builder.create<RT::MakeReadyFutureOp>(val.getLoc(), futType,
|
||||
newval, memrefCloned);
|
||||
val, memrefCloned);
|
||||
replaceAllUsesInDFTsInRegionWith(val, mrf, opBody);
|
||||
}
|
||||
}
|
||||
@@ -356,6 +341,59 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
|
||||
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
|
||||
}
|
||||
|
||||
static void propagateMemRefLayoutInDFTs(RT::DataflowTaskOp op, Value val,
|
||||
Value newval) {
|
||||
for (auto &use : llvm::make_early_inc_range(val.getUses()))
|
||||
if (use.getOwner()->getParentOfType<RT::DataflowTaskOp>() != nullptr) {
|
||||
OpBuilder builder(use.getOwner());
|
||||
Value cast_newval = builder.create<mlir::memref::CastOp>(
|
||||
val.getLoc(), val.getType(), newval);
|
||||
use.set(cast_newval);
|
||||
}
|
||||
}
|
||||
|
||||
static void cloneMemRefTaskArgumentsWithIdentityMaps(RT::DataflowTaskOp op) {
|
||||
OpBuilder builder(op);
|
||||
for (Value val : op.getOperands()) {
|
||||
if (val.getType().isa<mlir::MemRefType>()) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
|
||||
// Find out if this memref is needed in any other task to clone
|
||||
// before all uses
|
||||
SmallVector<Operation *, 2> taskOps;
|
||||
for (auto &use : val.getUses())
|
||||
if (isa<RT::DataflowTaskOp>(use.getOwner()))
|
||||
taskOps.push_back(use.getOwner());
|
||||
Operation *first = op;
|
||||
for (auto op : taskOps)
|
||||
if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first))
|
||||
first = op;
|
||||
builder.setInsertionPoint(first);
|
||||
|
||||
// Get the type of memref that we will clone. In case this is
|
||||
// a subview, we discard the mapping so we copy to a contiguous
|
||||
// layout which pre-serializes this.
|
||||
MemRefType mrType_base = val.getType().dyn_cast<mlir::MemRefType>();
|
||||
MemRefType mrType = mrType_base;
|
||||
if (!mrType_base.getLayout().isIdentity()) {
|
||||
unsigned rank = mrType_base.getRank();
|
||||
mrType = MemRefType::Builder(mrType_base)
|
||||
.setShape(mrType_base.getShape())
|
||||
.setLayout(AffineMapAttr::get(
|
||||
builder.getMultiDimIdentityMap(rank)));
|
||||
}
|
||||
Value newval = builder.create<mlir::memref::AllocOp>(val.getLoc(), mrType)
|
||||
.getResult();
|
||||
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
|
||||
// Value cast_newval = builder.create<mlir::memref::CastOp>(val.getLoc(),
|
||||
// mrType_base, newval);
|
||||
replaceAllUsesInDFTsInRegionWith(
|
||||
val, newval, op->getParentOfType<func::FuncOp>().getBody());
|
||||
propagateMemRefLayoutInDFTs(op, val, newval);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// For documentation see Autopar.td
|
||||
struct LowerDataflowTasksPass
|
||||
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
|
||||
@@ -376,6 +414,7 @@ struct LowerDataflowTasksPass
|
||||
SmallVector<std::pair<RT::DataflowTaskOp, func::FuncOp>, 4> outliningMap;
|
||||
|
||||
func.walk([&](RT::DataflowTaskOp op) {
|
||||
cloneMemRefTaskArgumentsWithIdentityMaps(op);
|
||||
auto workFunctionName =
|
||||
Twine("_dfr_DFT_work_function__") +
|
||||
Twine(op->getParentOfType<func::FuncOp>().getName()) +
|
||||
|
||||
Reference in New Issue
Block a user