From beba490ba88baa504e73364eb604e1e25dde10e2 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 19 Jan 2025 15:19:23 -0500 Subject: [PATCH] update mask in scaled_dot_product_attention (#8674) built is_causal mask with ones_like and start with boolean, and reversed the mask -inf order --- test/test_schedule.py | 2 +- tinygrad/tensor.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 8ceb444625..357741a1fe 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1000,7 +1000,7 @@ class TestSchedule(unittest.TestCase): def test_scaled_dot_product_attention_causal_fusion(self): x, y, z = (Tensor.empty(32, 8, 16, 16) for _ in range(3)) out = Tensor.scaled_dot_product_attention(x, y, z, is_causal=True) - check_schedule(out, 6) + check_schedule(out, 5) def test_adam_step_fusion(self): with Tensor.train(): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d60f58665f..ab89775cf1 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3565,8 +3565,7 @@ class Tensor(SimpleMathTrait): if num_classes == -1: num_classes = (self.max()+1).item() return self[..., None]._one_hot_along_dim(num_classes).where(1, 0) - def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, - dropout_p:float=0.0, is_causal:bool=False) -> Tensor: + def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: """ Computes scaled dot-product attention. `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor. @@ -3583,12 +3582,15 @@ class Tensor(SimpleMathTrait): """ # NOTE: it also works when `key` and `value` have symbolic shape. assert all_int(self.shape), f"does not support symbolic shape {self.shape}" + qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1]) + # handle attention mask if is_causal: if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True") - attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool) - if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0) - qk = self.matmul(key.transpose(-2,-1), acc_dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1]) - return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value + attn_mask = qk.ones_like(requires_grad=False, device=self.device, dtype=dtypes.bool).tril() + if attn_mask is not None: + if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf")) + qk = qk + attn_mask + return qk.softmax(-1).cast(self.dtype).dropout(dropout_p) @ value def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor: if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")