Files
concrete/compilers/concrete-compiler/compiler/lib/Analysis
Andi Drebes d620fa9a44 feat(compiler): Add pass hoisting RT.await_future out of scf.forall loops
The new pass hoists `RT.await_future` operations whose results are
yielded by scf.forall operations out of the loops in order to avoid
over-synchronization of data-flow tasks.

E.g., the following IR:

```
scf.forall (%arg) in (16)
  shared_outs(%o1 = %sometensor, %o2 = %someothertensor)
  -> (tensor<...>, tensor<...>)
{
  ...
  %rph = "RT.build_return_ptr_placeholder"() :
    () -> !RT.rtptr<!RT.future<tensor<...>>>
  "RT.create_async_task"(..., %rph, ...) { ... } : ...
  %future = "RT.deref_return_ptr_placeholder"(%rph) :
    (!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
  %res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
  ...
  scf.forall.in_parallel {
    ...
    tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] :
      tensor<...> into tensor<...>
    ...
  }
}
```

is transformed into:

```
%tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>>

scf.forall (%arg) in (16)
  shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor)
  -> (tensor<...>, tensor<...>)
{
  ...
  %rph = "RT.build_return_ptr_placeholder"() :
    () -> !RT.rtptr<!RT.future<tensor<...>>>
  "RT.create_async_task"(..., %rph, ...) { ... } : ...
  %future = "RT.deref_return_ptr_placeholder"(%rph) :
    (!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
  %wrappedfuture = tensor.from_elements %future :
    tensor<1x!RT.future<tensor<...>>>
  ...
  scf.forall.in_parallel {
    ...
    tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] :
      tensor<1xRT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>>
    ...
  }
}

scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) {
  %future = tensor.extract %tensoroffutures[%arg] :
    tensor<4x!RT.future<tensor<...>>>
  %res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] :
      tensor<...> into tensor<...>
  }
}
```
2024-04-08 16:16:07 +02:00
..