tk fa: use 16x64 tiles (#13086)

This commit is contained in:
wozeparrot
2025-11-03 18:25:38 -08:00
committed by GitHub
parent 4ed0f216b5
commit 9c00c0688a
2 changed files with 2 additions and 2 deletions

View File

@@ -10,7 +10,7 @@ constexpr int ATTN_N = 1024;
constexpr int ATTN_H = 16;
constexpr int ATTN_D = 64;
template<int D> constexpr size_t ROWS = 16*(128/D); // height of each worker tile (rows)
template<int D> constexpr size_t ROWS = 16*(64/D); // height of each worker tile (rows)
template<int D, typename T=bf16, typename L=row_l> using qkvo_tile = rt<T, ROWS<D>, D, L>;
template<int D, typename T=float> using attn_tile = rt<T, ROWS<D>, ROWS<D>>;
template<int D> using shared_tile = st_bf<ROWS<D>, D>;

View File

@@ -23,7 +23,7 @@ if __name__ == "__main__":
Tensor.realize(q, k, v, out)
NUM_WORKERS = 4
ROWS = 16 * (128 // D)
ROWS = 16 * (64 // D)
gsz = (N // (ROWS*NUM_WORKERS), H, B)
for _ in range(5):