ROCm/python
Lixun Zhang 821e75a2b0 Improve FA fwd kernel with causal=True (#356)
* Attempt to absorb upstream's changes to improve causal=True

* Add autotuner

* Optimize for AMD MI250

- add pre_load_v as a tuning parameter
- do not define N_CTX as constexpr
- perform the second dot before sum
- move qk_scale out of the inner loop
- add more configs in the autotuner

Note that the bwd kernel is disabled for now. With autotuning enabled,
grid becomes a function, so ctx.grid[0] no longer works.
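
The grid issue above can be illustrated with a minimal, dependency-free
sketch. All names here (cdiv, static_grid, autotuned_grid, and the shape
constants) are hypothetical stand-ins, not the kernel's actual code; the
only assumption taken from Triton is that a launch grid may be either a
tuple or a callable that receives the chosen meta-parameters.

```python
# Hypothetical stand-ins for illustration; not taken from the actual kernel.
def cdiv(a, b):
    """Ceiling division, as used to count blocks along a dimension."""
    return (a + b - 1) // b

N_CTX, BATCH, HEADS = 1024, 2, 4  # assumed example shapes

# Without autotuning, BLOCK_M is fixed, so the grid can be a plain tuple.
BLOCK_M = 128
static_grid = (cdiv(N_CTX, BLOCK_M), BATCH * HEADS, 1)
first_dim = static_grid[0]  # indexing a tuple works

# With autotuning, BLOCK_M depends on the selected config, so the grid
# has to be a function of the meta-parameters.
def autotuned_grid(meta):
    return (cdiv(N_CTX, meta["BLOCK_M"]), BATCH * HEADS, 1)

# A function cannot be subscripted, which is why code that saved the
# grid and later read grid[0] stops working.
try:
    autotuned_grid[0]
    subscript_ok = True
except TypeError:
    subscript_ok = False

# The grid has to be resolved with a concrete config before indexing.
resolved = autotuned_grid({"BLOCK_M": 64})
```

One possible way around this, under the same assumptions, is to store the
resolved tuple (or the values needed to recompute it) rather than the
grid function itself.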

* Enable bwd kernel
2023-10-12 12:34:27 -05:00