Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 15:38:29 -05:00
adding enable_gqa in SDPA (#11097)
Co-authored-by: wozeparrot <wozeparrot@gmail.com>
@@ -2906,6 +2906,17 @@ class TestOps(unittest.TestCase):
                                lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True,attn_mask=m),
                                expected=RuntimeError)
 
+  def test_scaled_dot_product_attention_gqa(self):
+    helper_test_op([(32,32,16,64), (32,8,16,64), (32,8,16,64)],
+      lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,enable_gqa=True),
+      lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True))
+
+  def test_scaled_dot_product_attention_gqa_errors(self):
+    self.helper_test_exception([(32,31,16,64), (32,8,16,64), (32,8,16,64)],
+      lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z),
+      lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,enable_gqa=True),
+      expected=(AssertionError, RuntimeError, ValueError))
+
   def test_binary_crossentropy(self):
     helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),y.clip(0,1)),
                                        lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
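The error test above relies on the grouped-query layout: the number of query heads must be an integer multiple of the number of key/value heads, which is why the (32,31,16,64) query is expected to fail against 8 key/value heads. A minimal sketch of that constraint (the helper name gqa_repeats is hypothetical, not part of the diff):

def gqa_repeats(n_query_heads: int, n_kv_heads: int) -> int:
  # Each key/value head is shared by a fixed-size group of query heads,
  # so the query head count must divide evenly by the kv head count.
  if n_query_heads % n_kv_heads != 0:
    raise ValueError(f"query heads ({n_query_heads}) not divisible by kv heads ({n_kv_heads})")
  return n_query_heads // n_kv_heads

assert gqa_repeats(32, 8) == 4  # shapes from the passing test
# gqa_repeats(31, 8) raises ValueError, mirroring the expected failure above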
@@ -3862,7 +3862,8 @@ class Tensor(MathTrait):
     if num_classes == -1: num_classes = (self.max()+1).item()
     return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
 
-  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
+  def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
+                                   is_causal:bool=False, enable_gqa:bool=False) -> Tensor:
     """
     Computes scaled dot-product attention.
     `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
@@ -3879,6 +3880,10 @@ class Tensor(MathTrait):
     """
     # NOTE: it also works when `key` and `value` have symbolic shape.
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
+    # GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    if enable_gqa:
+      key = key.repeat_interleave(self.shape[-3] // key.shape[-3], dim=-3)
+      value = value.repeat_interleave(self.shape[-3] // value.shape[-3], dim=-3)
     qk = self.matmul(key.transpose(-2,-1), dtype=least_upper_dtype(self.dtype, key.dtype, dtypes.float32)) / math.sqrt(self.shape[-1])
     # handle attention mask
     if is_causal:
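A minimal usage sketch of the new flag (the example shapes are assumptions, not from the diff): a query with 32 heads shares 8 key/value heads, which the implementation expands with repeat_interleave along the head axis before computing attention.

from tinygrad import Tensor

q = Tensor.randn(2, 32, 16, 64)  # (batch, query_heads, seq_len, head_dim)
k = Tensor.randn(2, 8, 16, 64)   # (batch, kv_heads, seq_len, head_dim)
v = Tensor.randn(2, 8, 16, 64)
out = q.scaled_dot_product_attention(k, v, is_causal=True, enable_gqa=True)
print(out.shape)                 # (2, 32, 16, 64)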