update mask in scaled_dot_product_attention (#8674)

Built the is_causal mask with ones_like, starting from a boolean mask, and reversed the order of -inf in the mask.
This commit is contained in:
chenyu
2025-01-19 15:19:23 -05:00
committed by GitHub
parent 5842ee56c6
commit beba490ba8
2 changed files with 9 additions and 7 deletions

View File

@@ -1000,7 +1000,7 @@ class TestSchedule(unittest.TestCase):
# Schedule test for Tensor.scaled_dot_product_attention called with is_causal=True.
# NOTE(review): this text is a diff hunk whose +/- markers and indentation were
# stripped, so the two check_schedule lines below are alternatives (old vs. new),
# not two consecutive statements of the real test body.
def test_scaled_dot_product_attention_causal_fusion(self):
# Three tensors created with Tensor.empty, each of shape (32, 8, 16, 16);
# presumably query, key and value -- confirm against the attention signature.
x, y, z = (Tensor.empty(32, 8, 16, 16) for _ in range(3))
out = Tensor.scaled_dot_product_attention(x, y, z, is_causal=True)
# Appears to be the removed (pre-change) line of the hunk.
check_schedule(out, 6)
# Appears to be the added (post-change) line of the hunk. The second argument is
# presumably the expected number of scheduled kernels -- confirm against check_schedule.
check_schedule(out, 5)
def test_adam_step_fusion(self):
with Tensor.train():