mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
adding enable_gqa in SDPA (#11097)
Co-authored-by: wozeparrot <wozeparrot@gmail.com>
This commit is contained in:
@@ -2906,6 +2906,17 @@ class TestOps(unittest.TestCase):
|
||||
lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True,attn_mask=m),
|
||||
expected=RuntimeError)
|
||||
|
||||
def test_scaled_dot_product_attention_gqa(self):
  """Run torch and tinygrad scaled_dot_product_attention with enable_gqa=True.

  The query shape has 32 entries along its second axis while the key and value
  shapes have 8 (the head axis in torch's documented SDPA layout), so the two
  callables are exercised with differing sizes on that axis.
  """
  query_shape = (32,32,16,64)
  kv_shape = (32,8,16,64)

  def torch_gqa(x, y, z):
    return torch.nn.functional.scaled_dot_product_attention(x, y, z, enable_gqa=True)

  def tinygrad_gqa(x, y, z):
    return Tensor.scaled_dot_product_attention(x, y, z, enable_gqa=True)

  # helper_test_op is defined elsewhere; it receives the input shapes, then the
  # torch callable, then the tinygrad callable, in that positional order.
  helper_test_op([query_shape, kv_shape, kv_shape], torch_gqa, tinygrad_gqa)
|
||||
|
||||
def test_scaled_dot_product_attention_gqa_errors(self):
  """Exercise the failure path of scaled_dot_product_attention with enable_gqa.

  The query shape has 31 entries along its second axis, which is not a multiple
  of the 8 used by the key and value shapes. Both callables are handed to
  self.helper_test_exception together with the accepted exception types.
  """
  query_shape = (32,31,16,64)
  kv_shape = (32,8,16,64)
  accepted_errors = (AssertionError, RuntimeError, ValueError)

  # NOTE(review): the torch callable is invoked without enable_gqa, unlike the
  # tinygrad callable below -- confirm this asymmetry is intentional.
  def torch_sdpa(x, y, z):
    return torch.nn.functional.scaled_dot_product_attention(x, y, z)

  def tinygrad_gqa(x, y, z):
    return Tensor.scaled_dot_product_attention(x, y, z, enable_gqa=True)

  self.helper_test_exception([query_shape, kv_shape, kv_shape], torch_sdpa, tinygrad_gqa,
                             expected=accepted_errors)
|
||||
|
||||
def test_binary_crossentropy(self):
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
|
||||
lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
|
||||
|
||||
Reference in New Issue
Block a user