mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
tk fa: use 16x64 tiles (#13086)
This commit is contained in:
@@ -10,7 +10,7 @@ constexpr int ATTN_N = 1024;
 constexpr int ATTN_H = 16;
 constexpr int ATTN_D = 64;

-template<int D> constexpr size_t ROWS = 16*(128/D); // height of each worker tile (rows)
+template<int D> constexpr size_t ROWS = 16*(64/D); // height of each worker tile (rows)
 template<int D, typename T=bf16, typename L=row_l> using qkvo_tile = rt<T, ROWS<D>, D, L>;
 template<int D, typename T=float> using attn_tile = rt<T, ROWS<D>, ROWS<D>>;
 template<int D> using shared_tile = st_bf<ROWS<D>, D>;
@@ -23,7 +23,7 @@ if __name__ == "__main__":
 Tensor.realize(q, k, v, out)

 NUM_WORKERS = 4
-ROWS = 16 * (128 // D)
+ROWS = 16 * (64 // D)

 gsz = (N // (ROWS*NUM_WORKERS), H, B)
 for _ in range(5):
Reference in New Issue
Block a user