adding enable_gqa in SDPA (#11097)

Co-authored-by: wozeparrot <wozeparrot@gmail.com>
This commit is contained in:
Nino Risteski
2025-07-07 08:25:33 +02:00
committed by GitHub
parent b73e89110e
commit a1a146a499
2 changed files with 17 additions and 1 deletions

View File

@@ -2906,6 +2906,17 @@ class TestOps(unittest.TestCase):
lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True,attn_mask=m),
expected=RuntimeError)
def test_scaled_dot_product_attention_gqa(self):
helper_test_op([(32,32,16,64), (32,8,16,64), (32,8,16,64)],
lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,enable_gqa=True),
lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True))
def test_scaled_dot_product_attention_gqa_errors(self):
self.helper_test_exception([(32,31,16,64), (32,8,16,64), (32,8,16,64)],
lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z),
lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True),
expected=(AssertionError, RuntimeError, ValueError))
def test_binary_crossentropy(self):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))

View File

@@ -3862,7 +3862,8 @@ class Tensor(MathTrait):
if num_classes == -1: num_classes = (self.max()+1).item()
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
@@ -3879,6 +3880,10 @@ class Tensor(MathTrait):
"""
# NOTE: it also works when `key` and `value` have symbolic shape.
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
# GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
if enable_gqa:
key = key.repeat_interleave(self.shape[-3] // key.shape[-3], dim=-3)
value = value.repeat_interleave(self.shape[-3] // value.shape[-3], dim=-3)
qk = self.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
# handle attention mask
if is_causal: