mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-16 23:51:36 -05:00
The new pass hoists `RT.await_future` operations whose results are
yielded by scf.forall operations out of the loops in order to avoid
over-synchronization of data-flow tasks.
E.g., the following IR:
```
scf.forall (%arg) in (16)
shared_outs(%o1 = %sometensor, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %res into %o1[..., %arg, ...] [...] [...] :
tensor<...> into tensor<...>
...
}
}
```
is transformed into:
```
%tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>>
scf.forall (%arg) in (16)
shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%wrappedfuture = tensor.from_elements %future :
tensor<1x!RT.future<tensor<...>>>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] :
tensor<1x!RT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>>
...
}
}
scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) {
%future = tensor.extract %tensoroffutures[%arg] :
tensor<16x!RT.future<tensor<...>>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
scf.forall.in_parallel {
tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] :
tensor<...> into tensor<...>
}
}
```