diff --git a/test/unit/test_attention.py b/test/unit/test_attention.py
new file mode 100644
index 0000000000..9929dbc32d
--- /dev/null
+++ b/test/unit/test_attention.py
@@ -0,0 +1,20 @@
+import unittest
+from tinygrad import Tensor, dtypes
+
+# TODO: test_scheduler, but just in uint
+class TestAttention(unittest.TestCase):
+  def test_half_qkv_buffers(self):
+    BS, seqlen, dim = 10, 4, 100
+    q = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
+    k = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
+    v = Tensor.ones(BS, seqlen, dim, dtype=dtypes.half).contiguous().realize()
+    attn = q.scaled_dot_product_attention(k, v)
+    sched = attn.schedule()
+    # attention has 5 kernels now
+    self.assertEqual(len(sched), 5)
+    softmax_inputs = sched[1:4]
+    for si in softmax_inputs:
+      assert all(b.dtype == dtypes.half for b in si.bufs), f"non half {si.bufs=}"
+
+if __name__ == '__main__':
+  unittest.main()
\ No newline at end of file
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index fa0853143e..bfe407bf17 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -3603,7 +3603,7 @@ class Tensor(SimpleMathTrait):
     if attn_mask is not None:
       if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
       qk = qk + attn_mask
-    return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
+    return qk.cast(self.dtype).softmax(-1).dropout(dropout_p) @ value
 
   def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
     if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
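
Note on the intent of the one-line change in `scaled_dot_product_attention`: moving `cast(self.dtype)` in front of `softmax(-1)` means that when q/k/v are half, the softmax kernels are scheduled against half buffers rather than wider ones, which is exactly what the new test asserts by inspecting `si.bufs`. Below is a minimal sketch (not part of the patch) for inspecting this interactively; it uses only the API already exercised in the test, and the shapes are arbitrary illustration values.

```python
from tinygrad import Tensor, dtypes

# Hypothetical shapes, chosen only for illustration.
q, k, v = [Tensor.ones(2, 4, 16, dtype=dtypes.half).contiguous().realize() for _ in range(3)]

attn = q.scaled_dot_product_attention(k, v)
# Print the buffer dtypes of each scheduled kernel; with the cast applied before
# the softmax, the softmax kernels should now see half buffers.
for i, si in enumerate(attn.schedule()):
  print(i, [b.dtype for b in si.bufs])
```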