mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
tk fa: use 16x64 tiles (#13086)
This commit is contained in:
@@ -10,7 +10,7 @@ constexpr int ATTN_N = 1024;
|
||||
constexpr int ATTN_H = 16;
|
||||
constexpr int ATTN_D = 64;
|
||||
|
||||
template<int D> constexpr size_t ROWS = 16*(128/D); // height of each worker tile (rows)
|
||||
template<int D> constexpr size_t ROWS = 16*(64/D); // height of each worker tile (rows)
|
||||
template<int D, typename T=bf16, typename L=row_l> using qkvo_tile = rt<T, ROWS<D>, D, L>;
|
||||
template<int D, typename T=float> using attn_tile = rt<T, ROWS<D>, ROWS<D>>;
|
||||
template<int D> using shared_tile = st_bf<ROWS<D>, D>;
|
||||
|
||||
@@ -23,7 +23,7 @@ if __name__ == "__main__":
|
||||
Tensor.realize(q, k, v, out)
|
||||
|
||||
NUM_WORKERS = 4
|
||||
ROWS = 16 * (128 // D)
|
||||
ROWS = 16 * (64 // D)
|
||||
|
||||
gsz = (N // (ROWS*NUM_WORKERS), H, B)
|
||||
for _ in range(5):
|
||||
|
||||
Reference in New Issue
Block a user