diff --git a/extra/thunder/cuda/fa.cu b/extra/thunder/cuda/fa.cu index 12d24cf41c..a18e2f155f 100644 --- a/extra/thunder/cuda/fa.cu +++ b/extra/thunder/cuda/fa.cu @@ -10,7 +10,7 @@ constexpr int ATTN_N = 1024; constexpr int ATTN_H = 16; constexpr int ATTN_D = 64; -template constexpr size_t ROWS = 16*(128/D); // height of each worker tile (rows) +template constexpr size_t ROWS = 16*(64/D); // height of each worker tile (rows) template using qkvo_tile = rt, D, L>; template using attn_tile = rt, ROWS>; template using shared_tile = st_bf, D>; diff --git a/extra/thunder/cuda/fa.py b/extra/thunder/cuda/fa.py index bfa95b080c..fd0c5bede7 100644 --- a/extra/thunder/cuda/fa.py +++ b/extra/thunder/cuda/fa.py @@ -23,7 +23,7 @@ if __name__ == "__main__": Tensor.realize(q, k, v, out) NUM_WORKERS = 4 - ROWS = 16 * (128 // D) + ROWS = 16 * (64 // D) gsz = (N // (ROWS*NUM_WORKERS), H, B) for _ in range(5):