Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
tk: fa cleanups + causal test (#13963)
@@ -340,7 +340,6 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
mask = Tensor(kernel.src[5])

delta_vec = (grad * attn).sum(-1).transpose(1, 2).unsqueeze(-2).detach()
print(l_vec.shape, delta_vec.shape, grad.shape, attn.shape, grad_q.shape, grad_k.shape, grad_v.shape)

grad_q = Tensor.custom_kernel(grad_q, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]
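For orientation, below is a minimal NumPy sketch of the unfused attention backward pass that custom kernels like the ones in this hunk replace. The `delta` statistic corresponds to the `delta_vec` built above (a row-wise sum of the upstream gradient times the forward output); the function name, the `scale` argument, and the leading (batch, heads) dimensions are assumptions for illustration, not tinygrad's API or the actual fused kernels.

import numpy as np

def naive_attention_backward(q, k, v, grad_out, scale):
    # Recompute the forward pass; fused flash-attention kernels avoid
    # materializing the full (seq_q, seq_k) probability matrix p.
    s = (q @ k.swapaxes(-1, -2)) * scale            # attention scores
    p = np.exp(s - s.max(-1, keepdims=True))
    p /= p.sum(-1, keepdims=True)                   # softmax probabilities
    o = p @ v                                       # forward output

    # delta_i = sum_j dO_ij * O_ij, one scalar per query row; this plays the
    # role of `delta_vec` = (grad * attn).sum(-1) in the diff above.
    delta = (grad_out * o).sum(-1, keepdims=True)

    grad_v = p.swapaxes(-1, -2) @ grad_out          # dV = P^T dO
    dp = grad_out @ v.swapaxes(-1, -2)              # dP = dO V^T
    ds = p * (dp - delta)                           # softmax backward
    grad_q = (ds @ k) * scale                       # dQ = dS K
    grad_k = (ds.swapaxes(-1, -2) @ q) * scale      # dK = dS^T Q
    return grad_q, grad_k, grad_v

In the fused path, the saved log-sum-exp statistics (`l_vec`) and the mask let the backward kernels recompute the probabilities block by block instead of redoing the full softmax, which is why both are passed into Tensor.custom_kernel above.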