diff --git a/test/test_ops.py b/test/test_ops.py index ef1c1965cd..426c621651 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2906,6 +2906,17 @@ class TestOps(unittest.TestCase): lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True,attn_mask=m), expected=RuntimeError) + def test_scaled_dot_product_attention_gqa(self): + helper_test_op([(32,32,16,64), (32,8,16,64), (32,8,16,64)], + lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,enable_gqa=True), + lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True)) + + def test_scaled_dot_product_attention_gqa_errors(self): + self.helper_test_exception([(32,31,16,64), (32,8,16,64), (32,8,16,64)], + lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), + lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True), + expected=(AssertionError, RuntimeError, ValueError)) + def test_binary_crossentropy(self): helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1))) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 9d0378955b..d49dd21b45 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3862,7 +3862,8 @@ class Tensor(MathTrait): if num_classes == -1: num_classes = (self.max()+1).item() return self[..., None]._one_hot_along_dim(num_classes).where(1, 0) - def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: + def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, + is_causal:bool=False, enable_gqa:bool=False) -> Tensor: """ Computes scaled dot-product attention. `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor. @@ -3879,6 +3880,10 @@ class Tensor(MathTrait): """ # NOTE: it also works when `key` and `value` have symbolic shape. assert all_int(self.shape), f"does not support symbolic shape {self.shape}" + # GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + if enable_gqa: + key = key.repeat_interleave(self.shape[-3] // key.shape[-3], dim=-3) + value = value.repeat_interleave(self.shape[-3] // value.shape[-3], dim=-3) qk = self.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1]) # handle attention mask if is_causal: