Andi Drebes
51b5b42c05
fix(compiler): Support indirect references to IVs in indexes when hoisting RT.await_future ops
...
Until now, the pass hoisting `RT.await_future` operations only
supports `tensor.parallel_insert_slice` operations that use loop
induction variables directly as indexes. Any more complex indexing
expressions produce a domination error, since a
`tensor.parallel_insert_slice` cloned by the pass into an additional
parallel for loop is left with references to values from the original
loop.
This change properly clones operations producing intermediate values
within the original parallel for loop and thus adds support for
indexing expressions that reference loops IVs only indirectly.
2024-04-25 11:05:54 +02:00
Andi Drebes
d620fa9a44
feat(compiler): Add pass hoisting RT.await_future out of scf.forall loops
...
The new pass hoists `RT.await_future` operations whose results are
yielded by scf.forall operations out of the loops in order to avoid
over-synchronization of data-flow tasks.
E.g., the following IR:
```
scf.forall (%arg) in (16)
shared_outs(%o1 = %sometensor, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %res into %o1[..., %arg2, ...] [...] [...] :
tensor<...> into tensor<...>
...
}
}
```
is transformed into:
```
%tensoroffutures = tensor.empty() : tensor<16x!RT.future<tensor<...>>>
scf.forall (%arg) in (16)
shared_outs(%otfut = %tensoroffutures, %o2 = %someothertensor)
-> (tensor<...>, tensor<...>)
{
...
%rph = "RT.build_return_ptr_placeholder"() :
() -> !RT.rtptr<!RT.future<tensor<...>>>
"RT.create_async_task"(..., %rph, ...) { ... } : ...
%future = "RT.deref_return_ptr_placeholder"(%rph) :
(!RT.rtptr<!RT.future<...>>) -> !RT.future<tensor<...>>
%wrappedfuture = tensor.from_elements %future :
tensor<1x!RT.future<tensor<...>>>
...
scf.forall.in_parallel {
...
tensor.parallel_insert_slice %wrappedfuture into %otfut[%arg] [1] [1] :
tensor<1xRT.future<tensor<...>>> into tensor<16x!RT.future<tensor<...>>>
...
}
}
scf.forall (%arg) in (16) shared_outs(%o = %sometensor) -> (tensor<...>) {
%future = tensor.extract %tensoroffutures[%arg] :
tensor<4x!RT.future<tensor<...>>>
%res = "RT.await_future"(%future) : (!RT.future<tensor<...>>) -> tensor<...>
scf.forall.in_parallel {
tensor.parallel_insert_slice %res into %o[..., %arg, ...] [...] [...] :
tensor<...> into tensor<...>
}
}
```
2024-04-08 16:16:07 +02:00