adding enable_gqa in SDPA (#11097)

Co-authored-by: wozeparrot <wozeparrot@gmail.com>
Author: Nino Risteski
Date: 2025-07-07 08:25:33 +02:00
Committed by: GitHub
Parent: b73e89110e
Commit: a1a146a499
2 changed files with 17 additions and 1 deletion
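For context, grouped-query attention (GQA) lets several query heads share one key/value head: with enable_gqa=True, SDPA accepts 8 kv heads against 32 query heads and broadcasts each kv head across its group of 4 query heads. Below is a minimal sketch of that semantics using the shapes from the new test; it assumes a PyTorch build recent enough to accept enable_gqa, and the repeat_interleave equivalence follows PyTorch's documented behavior:

import torch
import torch.nn.functional as F

B, H_q, H_kv, L, D = 32, 32, 8, 16, 64        # shapes from the new test below
q = torch.randn(B, H_q, L, D)
k = torch.randn(B, H_kv, L, D)                # 8 kv heads serving 32 query heads
v = torch.randn(B, H_kv, L, D)

out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# equivalent expansion: repeat each kv head H_q // H_kv = 4 times
k_rep = k.repeat_interleave(H_q // H_kv, dim=-3)
v_rep = v.repeat_interleave(H_q // H_kv, dim=-3)
ref = F.scaled_dot_product_attention(q, k_rep, v_rep)
assert torch.allclose(out, ref, atol=1e-5)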


@@ -2906,6 +2906,17 @@ class TestOps(unittest.TestCase):
       lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True,attn_mask=m),
       expected=RuntimeError)
+
+  def test_scaled_dot_product_attention_gqa(self):
+    helper_test_op([(32,32,16,64), (32,8,16,64), (32,8,16,64)],
+      lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,enable_gqa=True),
+      lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True))
+
+  def test_scaled_dot_product_attention_gqa_errors(self):
+    self.helper_test_exception([(32,31,16,64), (32,8,16,64), (32,8,16,64)],
+      lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z),
+      lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True),
+      expected=(AssertionError, RuntimeError, ValueError))
   def test_binary_crossentropy(self):
     helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
       lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
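The second test checks the failure mode: 31 query heads are not divisible by 8 kv heads, so the grouping is ill-defined and the call must raise. The tinygrad-side handling lives in the other changed file (tensor.py, not shown in this hunk); the sketch below shows the standard GQA expansion on tinygrad Tensors under that assumption, with expand_kv_for_gqa being an illustrative helper name rather than anything from the commit:

from tinygrad import Tensor

def expand_kv_for_gqa(query: Tensor, key: Tensor, value: Tensor):
  # repeat each kv head q_heads // kv_heads times so key/value line up with the
  # query head count; this is the standard GQA expansion, assumed here rather
  # than copied from the commit's tensor.py hunk
  q_heads, kv_heads = query.shape[-3], key.shape[-3]
  assert kv_heads > 0 and q_heads % kv_heads == 0, f"{q_heads} query heads not divisible by {kv_heads} kv heads"
  reps = q_heads // kv_heads
  return key.repeat_interleave(reps, dim=-3), value.repeat_interleave(reps, dim=-3)

q, k, v = Tensor.randn(32, 32, 16, 64), Tensor.randn(32, 8, 16, 64), Tensor.randn(32, 8, 16, 64)
k2, v2 = expand_kv_for_gqa(q, k, v)
out = q.scaled_dot_product_attention(k2, v2)  # should match enable_gqa=True on the raw k, v
print(out.shape)  # (32, 32, 16, 64)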