[BACKEND] Fix scan issues on repetitive warps and improve perf when there's a single warp on the axis (#2330)

1. On the axis, using `getAxisNumWarpsWithUniqueData` instead of getting
the raw number of warps to avoid communication among warps that handle
the same piece of data.
2. When there's a single warp on the axis, using warp Intrinsics for
communication and skip shared memory.

Need a follow up PR for code clean up.
This commit is contained in:
Keren Zhou
2023-09-18 17:45:05 -04:00
committed by GitHub
parent a9ae9886dc
commit 307b5caa49
6 changed files with 195 additions and 24 deletions

View File

@@ -88,6 +88,8 @@ public:
unsigned getNonAxisNumThreadsPerCTA();
// Return the number of warps per CTA along axis dim.
unsigned getAxisNumWarps();
// Return the number of warps per CTA along axis dim with unique data.
unsigned getAxisNumWarpsWithUniqueData();
// Return the number of threads per warp along axis dim.
unsigned getAxisNumThreadsPerWarp();
// Return the number of blocks along axis dim.