mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 16:11:26 -05:00
feat(compiler): Add pass hoisting RT.await_future out of scf.forall loops
The new pass hoists `RT.await_future` operations whose results are
yielded by scf.forall operations out of the loops in order to avoid
over-synchronization of data-flow tasks.
E.g., the following IR:
```
scf.forall (%arg) in (16)
shared_outs(%o1 = %sometensor, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %res into %o1[..., %arg, ...] [...] [...] :
tensor<...> into tensor<...>
...
}
}
```
is transformed into:
```
%tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>>
scf.forall (%arg) in (16)
shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%wrappedfuture = tensor.from_elements %future :
tensor<1x!RT.future<tensor<...>>>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] :
tensor<1x!RT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>>
...
}
}
scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) {
%future = tensor.extract %tensoroffutures[%arg] :
tensor<16x!RT.future<tensor<...>>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
scf.forall.in_parallel {
tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] :
tensor<...> into tensor<...>
}
}
```
This commit is contained in:
@@ -377,6 +377,11 @@ void TFHEKeyNormalizationPass::runOnOperation() {
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<mlir::tensor::DimOp>(
|
||||
target, typeConverter);
|
||||
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::tensor::ParallelInsertSliceOp>>(&getContext(), typeConverter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::tensor::ParallelInsertSliceOp>(target, typeConverter);
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
conversion::TypeConverter>>(
|
||||
&getContext(), typeConverter);
|
||||
|
||||
Reference in New Issue
Block a user