move cast to before softmax in attention (#9213)

* move cast to before softmax in attention

saves some memory because the exp (which is used for backward) is done in half. training BERT seems fine and can fit BS=78 now (up from 66); see the dtype sketch below.

* test
chenyu, 2025-02-24 17:24:59 -05:00 (committed by GitHub)
parent f0b24d230c
commit 90c3ed17c5
2 changed files with 21 additions and 1 deletion
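
Why the one-line reorder saves memory: qk is produced in a wider dtype than the half inputs (hence the cast back to self.dtype), and softmax starts with an exp whose output is kept for the backward pass. Casting before the softmax makes that saved exp buffer half instead of the wider dtype. A minimal sketch of the dtype effect, using only the public Tensor API; shapes and values are illustrative, not from the commit:

from tinygrad import Tensor, dtypes

# qk stands in for the attention scores right before the softmax (float32 by default)
qk = Tensor.randn(4, 8, 8)

old = qk.softmax(-1).cast(dtypes.half)  # old ordering: softmax (and its exp) runs in float32, cast to half afterwards
new = qk.cast(dtypes.half).softmax(-1)  # new ordering: cast first, so softmax/exp runs in half
print(old.dtype, new.dtype)             # both end up half; only the intermediates differ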


@@ -0,0 +1,20 @@
import unittest
from tinygrad import Tensor, dtypes

# TODO: test_scheduler, but just in uint
class TestAttention(unittest.TestCase):
  def test_half_qkv_buffers(self):
    BS, seqlen, dim = 10, 4, 100
    q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    k = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    attn = q.scaled_dot_product_attention(k, v)
    sched = attn.schedule()
    # attention has 5 kernels now
    self.assertEqual(len(sched), 5)
    softmax_inputs = sched[1:4]
    for si in softmax_inputs:
      assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=}"

if __name__ == '__main__':
  unittest.main()


@@ -3603,7 +3603,7 @@ class Tensor(SimpleMathTrait):
    if attn_mask is not None:
      if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
      qk = qk + attn_mask
-    return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
+    return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
  def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
    if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
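
Put together, a self-contained sketch of the new ordering in the attention tail. sdpa_sketch is a hypothetical stand-in, not tinygrad's scaled_dot_product_attention: it skips the attn_mask handling and upcasts qk explicitly to mimic the wider score dtype, so the only point it illustrates is where the cast now sits relative to softmax and dropout:

import math
from tinygrad import Tensor, dtypes

def sdpa_sketch(q: Tensor, k: Tensor, v: Tensor, dropout_p: float = 0.0) -> Tensor:
  # scores computed in float32 to mimic the wider qk dtype
  qk = (q.cast(dtypes.float32) @ k.cast(dtypes.float32).transpose(-2, -1)) / math.sqrt(q.shape[-1])
  # cast to the input dtype before softmax, as in the diff above
  return qk.cast(q.dtype).softmax(-1).dropout(dropout_p) @ v

q = k = v = Tensor.ones(10, 4, 100, dtype=dtypes.half).contiguous()
print(sdpa_sketch(q, k, v).dtype)  # dtypes.half

With the cast after the softmax (the old ordering), the exp and the rest of the softmax would instead stay in the wider dtype until the cast, which is the extra memory the commit message refers to.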