move cast to before softmax in attention (#9213)
* move cast to before softmax in attention: this saves some memory because exp (which is used for backward) is done in half. Training BERT seems fine and can fit BS=78 now (up from 66).
* test
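For context, a minimal sketch of what the reordering changes (the shapes and the explicit float32 cast standing in for the float32-accumulated scores are illustrative assumptions, not part of the diff): casting to half before softmax means the exp intermediate kept for the backward pass is stored in half instead of float32.

from tinygrad import Tensor, dtypes

q = Tensor.ones(4, 8, 16, dtype=dtypes.half)
k = Tensor.ones(4, 8, 16, dtype=dtypes.half)
# stand-in for the attention scores, which are accumulated in float32
qk = (q @ k.transpose(-2, -1)).cast(dtypes.float32)

# old order: softmax (and the exp saved for backward) runs in float32, cast happens afterwards
old = qk.softmax(-1).cast(q.dtype)
# new order: cast first, so softmax and the saved exp stay in half
new = qk.cast(q.dtype).softmax(-1)
print(old.dtype, new.dtype)  # both half at the output; only the intermediates differ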
test/unit/test_attention.py (new file, +20 lines)
@@ -0,0 +1,20 @@
import unittest
from tinygrad import Tensor, dtypes

# TODO: test_scheduler, but just in uint
class TestAttention(unittest.TestCase):
  def test_half_qkv_buffers(self):
    BS, seqlen, dim = 10, 4, 100
    q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    k = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
    attn = q.scaled_dot_product_attention(k, v)
    sched = attn.schedule()
    # attention has 5 kernels now
    self.assertEqual(len(sched), 5)
    softmax_inputs = sched[1:4]
    for si in softmax_inputs:
      assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=}"

if __name__ == '__main__':
  unittest.main()
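The same check can be run interactively to dump each scheduled kernel's buffer dtypes; a quick sketch using the shapes from the test above:

from tinygrad import Tensor, dtypes

q = Tensor.ones(10, 4, 100, dtype=dtypes.half).contiguous().realize()
k = Tensor.ones(10, 4, 100, dtype=dtypes.half).contiguous().realize()
v = Tensor.ones(10, 4, 100, dtype=dtypes.half).contiguous().realize()
sched = q.scaled_dot_product_attention(k, v).schedule()
for i, si in enumerate(sched):
  print(i, [b.dtype for b in si.bufs])  # the softmax kernels (1..3) should only touch half buffers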
tinygrad/tensor.py
@@ -3603,7 +3603,7 @@ class Tensor(SimpleMathTrait):
     if attn_mask is not None:
       if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
       qk = qk + attn_mask
-    return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
+    return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value

   def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
     if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
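As the context above shows, a boolean attn_mask is converted to 0/-inf before being added to the scores. A small usage sketch (the shapes and the lower-triangular mask are illustrative assumptions):

from tinygrad import Tensor, dtypes

q, k, v = [Tensor.ones(1, 4, 4, 8, dtype=dtypes.half) for _ in range(3)]
mask = Tensor.ones(4, 4).tril().cast(dtypes.bool)  # boolean causal mask; True means "attend"
out = q.scaled_dot_product_attention(k, v, attn_mask=mask)
print(out.shape, out.dtype)  # (1, 4, 4, 8), half thanks to the pre-softmax cast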